#include <cub/cub.cuh>
#include <thrust/copy.h>
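
//
// persistent work-queue kernel: each iteration pulls new work-units from the stream
// into an input queue, runs them, and stream-compacts the units that still have a
// continuation into an output queue for the next iteration (order-preserving)
//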
    const WorkMoverT        mover)

    typedef WorkUnitT WorkUnit;

    typedef cub::BlockReduce<uint32,BLOCKDIM> BlockReduce;
    typedef cub::BlockScan<uint32,BLOCKDIM>   BlockBitScan;

    const uint32 queue_capacity = grid_threads * n_tile_grids;
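
    // the input and output work queues are the two halves of a single device buffer,
    // and are ping-ponged at the end of every iteration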
    uint32           work_queue_id0   = 0u;
    WorkUnit*        work_queue0      = context.m_work_queue;
    WorkUnit*        work_queue1      = context.m_work_queue + queue_capacity;
    volatile uint32* work_queue_size0 = context.m_work_queue_size;
    volatile uint32* work_queue_size1 = context.m_work_queue_size + 1;
    __shared__ typename BlockBitScan::TempStorage scan_smem_storage;
    __shared__ typename BlockReduce::TempStorage  reduce_smem_storage;

    const uint32 COND_THRESHOLD = 10000000;
#define SINGLE_CTA_SCAN
#if !defined(SINGLE_CTA_SCAN)
    __shared__ volatile uint32 previous_done;
#endif

    uint8* continuations = context.m_continuations;

#if defined(QUEUE_GATHER)
    uint32* source_ids = context.m_source_ids;
#endif

    volatile uint32* partials = context.m_partials;
    volatile uint32* prefixes = context.m_prefixes;

    const bool is_thread0 = (threadIdx.x == 0);
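
        // main loop: a grid-wide barrier makes the previous iteration's output queue and
        // size visible to all CTAs before they are read back as this iteration's input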
        context.m_syncblocks.enact();

        uint32 in_work_queue_size = *work_queue_size0;
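
        // top up the input queue from the stream: only the tile straddling the current
        // queue tail fetches new work-units, appending them right after the existing ones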
        for (uint32 i = 0; i < n_tile_grids; ++i)
            const uint32 work_begin = grid_threads * i;

            if ((work_begin <= in_work_queue_size) &&
                (work_begin + grid_threads > in_work_queue_size) &&
                (stream_begin < stream_end))
                const uint32 n_loaded = nvbio::min( stream_end - stream_begin, grid_threads - (in_work_queue_size - work_begin) );

                if ((work_id >= in_work_queue_size) &&
                    (work_id - in_work_queue_size < n_loaded))
                    stream.get( stream_begin + work_id - in_work_queue_size, work_queue0 + work_id, make_uint2( work_id, work_queue_id0 ) );

                in_work_queue_size += n_loaded;
                stream_begin       += n_loaded;
        if (in_work_queue_size == 0)

        NVBIO_CUDA_DEBUG_ASSERT( (__syncthreads_and( true ) == true),
            "error: not all threads are active at block: %u\n", blockIdx.x );
        const uint32 n_active_tile_grids = (in_work_queue_size + grid_threads-1) / grid_threads;
        const uint32 n_active_tiles      = n_active_tile_grids * gridDim.x;
#if defined(SINGLE_CTA_SCAN)
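        //
        // phase 1: run every work-unit in the input queue, reduce the continuation flags
        // of each tile (one CTA-sized chunk of the queue) into a single count, and publish
        // it in 'partials', signalling PARTIAL_READY for that tile
        //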
        for (uint32 i = 0; i < n_active_tile_grids; ++i)
            const uint32 tile_idx = blockIdx.x + gridDim.x * i;

            bool has_continuation = false;

            if (work_id < in_work_queue_size)
                has_continuation = work_queue0[ work_id ].run( stream );

            continuations[ work_id ] = has_continuation ? 1u : 0u;

            const uint32 block_aggregate = BlockReduce( reduce_smem_storage ).Sum( has_continuation ? 1 : 0 );

            partials[tile_idx] = block_aggregate;

            conditions[tile_idx].set( iteration*2 + PARTIAL_READY );
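
        //
        // phase 2: a single CTA (hence SINGLE_CTA_SCAN) waits for all the per-tile partials
        // and exclusive-scans them, BLOCKDIM tiles at a time, carrying the running sum in
        // 'carry'; each tile's global prefix is published in 'prefixes' (PREFIX_READY),
        // and the final carry becomes the size of the output queue
        //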
        for (uint32 tile_begin = 0; tile_begin < n_active_tiles; tile_begin += BLOCKDIM)
            const uint32 tile_x     = tile_begin + threadIdx.x;
            const bool   tile_valid = (tile_x < n_active_tiles);

            if (!conditions[tile_x].wait( iteration*2 + PARTIAL_READY, COND_THRESHOLD ))
                printf("PARTIAL_READY test failed for tile %u/%u at iteration %u\n", tile_x, n_active_tiles, iteration);

            BlockBitScan( scan_smem_storage ).ExclusiveSum( partial, partial_scan, partial_aggregate );

            prefixes[tile_x] = carry + partial_scan;
            conditions[tile_x].set( iteration*2 + PREFIX_READY );

            carry += partial_aggregate;

        *work_queue_size1 = carry;
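
        //
        // phase 3: every CTA waits for the prefix of its tiles, exclusive-scans its own
        // continuation flags within the block, and scatters the surviving work-units
        // (or just their source indices, under QUEUE_GATHER) to 'prefix + block_scan'
        // in the output queue
        //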
        for (uint32 i = 0; i < n_active_tile_grids; ++i)
            const uint32 tile_idx = blockIdx.x + gridDim.x * i;

            if (!conditions[tile_idx].wait( iteration*2 + PREFIX_READY, COND_THRESHOLD ))
                printf("PREFIX_READY test failed for tile %u/%u at iteration %u\n", tile_idx, n_active_tiles, iteration);

            const uint32 prefix = prefixes[tile_idx];

            BlockBitScan( scan_smem_storage ).ExclusiveSum( has_continuation, block_scan, block_aggregate );

            if (has_continuation)
          #if defined(QUEUE_GATHER)
                source_ids[ prefix + block_scan ] = work_id;

                    make_uint2( work_id,             work_queue_id0 ),           &work_queue0[ work_id ],
                    make_uint2( prefix + block_scan, work_queue_id0 ? 0u : 1u ), &work_queue1[ prefix + block_scan ] );
#else // !SINGLE_CTA_SCAN
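        //
        // alternative single-pass scheme: each CTA computes its own tile prefix directly,
        // either by chaining on its immediate predecessor or by adaptively looking back
        // over the preceding tiles, instead of relying on a separate scanning CTA
        //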
        for (uint32 i = 0; i < n_tile_grids; ++i)
            const uint32 tile_idx = blockIdx.x + gridDim.x * i;

            bool has_continuation = false;

            if (work_id < in_work_queue_size)
                has_continuation = work_queue0[ work_id ].run( stream );

            BlockBitScan( scan_smem_storage ).ExclusiveSum( has_continuation ? 1 : 0, block_scan, block_aggregate );

            partials[tile_idx] = block_aggregate;

            if (is_thread0 && tile_idx)
                previous_done = conditions[tile_idx-1].test( iteration*2 + PREFIX_READY );

            prefixes[0] = block_aggregate;
          #if 0 // simple chaining
            if (!conditions[tile_idx-1].wait( iteration*2 + PREFIX_READY, COND_THRESHOLD ))
                printf("PREFIX_READY test failed for tile %u/%u at iteration %u\n", tile_idx-1, n_active_tiles, iteration);

            prefix = prefixes[tile_idx-1];

            prefixes[tile_idx] = prefix + block_aggregate;
          #else // adaptive lookback
            else if (previous_done)
                prefix = prefixes[tile_idx-1];

                prefixes[tile_idx] = prefix + block_aggregate;

            conditions[tile_idx].set( iteration*2 + PARTIAL_READY );
            int32 last_tile   = tile_idx;
            int32 prefix_tile = tile_idx;
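
            // adaptive lookback: walk backwards over the preceding tiles in BLOCKDIM-wide
            // windows, jumping directly to an already published prefix when one is found,
            // otherwise accumulating the ready partials, until a PREFIX_READY tile or
            // tile 0 is reached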
                if (prefix_tile + threadIdx.x < last_tile)
                    if (conditions[ prefix_tile + threadIdx.x ].test( iteration*2 + PREFIX_READY ))
                        previous_done = prefix_tile + threadIdx.x;

                prefix_tile = previous_done;

                if (prefix_tile + threadIdx.x < last_tile)
                    if (previous_done && threadIdx.x == 0)
                        partial = prefixes[ prefix_tile ];

                    if (!conditions[ prefix_tile + threadIdx.x ].wait( iteration*2 + PARTIAL_READY, COND_THRESHOLD ))
                        printf("PARTIAL_READY test failed for tile %u at tile %u/%u at iteration %u\n", prefix_tile + threadIdx.x, tile_idx-1, n_active_tiles, iteration);

                    partial = partials[ prefix_tile + threadIdx.x ];

                prefix += BlockReduce( reduce_smem_storage ).Sum( partial );

                last_tile = prefix_tile;

            while (prefix_tile && !previous_done);
            prefixes[tile_idx] = prefix + block_aggregate;

            previous_done = prefix;

            prefix = previous_done;

            if (tile_idx == n_active_tiles-1 && is_thread0)
                *work_queue_size1 = prefix + block_aggregate;
            if (has_continuation)
          #if defined(QUEUE_GATHER)
                source_ids[ prefix + block_scan ] = work_id;

                    make_uint2( work_id,             work_queue_id0 ),           &work_queue0[ work_id ],
                    make_uint2( prefix + block_scan, work_queue_id0 ? 0u : 1u ), &work_queue1[ prefix + block_scan ] );

            conditions[tile_idx].set( iteration*2 + PREFIX_READY );
        if (!conditions[n_active_tiles-1].wait( iteration*2 + PREFIX_READY, COND_THRESHOLD ))
            printf("PREFIX_READY test failed for last tile (%u) at iteration %u\n", n_active_tiles-1, iteration);
#endif // !SINGLE_CTA_SCAN
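
        // swap the roles of the input and output queues (and of their size counters)
        // for the next iteration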
        {
            WorkUnit* tmp = work_queue0;
            work_queue0   = work_queue1;
            work_queue1   = tmp;
        }
        {
            volatile uint32* tmp = work_queue_size0;
            work_queue_size0 = work_queue_size1;
            work_queue_size1 = tmp;
        }

        work_queue_id0 = work_queue_id0 ? 0u : 1u;
#if defined(QUEUE_GATHER)
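        // optional gather pass: after another grid-wide barrier, each surviving work-unit
        // is moved from its old slot (recorded in source_ids during the scan) to its
        // compacted position in the new input queue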
        context.m_syncblocks.enact();

        const uint32 out_grid_size       = *work_queue_size0;
        const uint32 n_active_tile_grids = (out_grid_size + grid_threads-1) / grid_threads;

        for (uint32 i = 0; i < n_active_tile_grids; ++i)
            if (work_id < out_grid_size)

                    make_uint2( src_id,  work_queue_id0 ? 0u : 1u ), &work_queue1[ src_id ],
                    make_uint2( work_id, work_queue_id0 ? 1u : 0u ), &work_queue0[ work_id ] );

#endif // QUEUE_GATHER
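
//
// host-side setup: size the per-tile scratch buffers (conditions, partials, prefixes),
// the continuation flags, the source-id buffer and the double work queue, then launch
// the persistent kernel with n_blocks CTAs of BLOCKDIM threads
//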
template <
    typename WorkStream,
    typename WorkMover>
    const uint32 n_tile_grids = m_capacity / grid_threads;

    m_condition_set.resize( n_blocks*n_tile_grids );
    m_partials.resize( n_blocks*n_tile_grids );
    m_prefixes.resize( n_blocks*n_tile_grids );
    m_continuations.resize( grid_threads*n_tile_grids );
    m_source_ids.resize( grid_threads*n_tile_grids );
    m_work_queue.resize( grid_threads*n_tile_grids * 2 );
    m_work_queue_size.resize( 2 );
    m_syncblocks.clear();

    m_work_queue_size[0] = 0;

    wq::work_queue_kernel<BLOCKDIM,WorkUnit,WorkStream,WorkMover> <<<n_blocks,BLOCKDIM>>>( n_tile_grids, get_context(), stream, mover );
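
//
// Illustrative sketch (not from the original source): a minimal work-stream / work-unit
// pair matching the interface the kernel above relies on. The kernel only needs
// stream.get() to construct a unit into a given queue slot and WorkUnit::run() to return
// true while the unit still has a continuation; the mover simply relocates units between
// the (slot, pointer) pairs it is handed. The names below (CountdownUnit, CountdownStream)
// and the size() method are assumptions, not part of the file above.
//
struct CountdownUnit
{
    // keep the unit in the queue until its counter runs out
    template <typename StreamT>
    __device__ bool run(const StreamT& stream) { return --m_count > 0; }

    uint32 m_count;
};

struct CountdownStream
{
    __host__ __device__ uint32 size() const { return m_size; }

    // build the i-th work-unit directly into the queue slot it was assigned
    __device__ void get(const uint32 i, CountdownUnit* unit, const uint2 slot) const
    {
        unit->m_count = 1u + (i % 8u);
    }

    uint32 m_size;
};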