You can do this using shared memory (which keeps the data "on-chip"). I'm not sure how to do it using strictly registers without deconstructing the BlockRadixSort object.
Here is an example that uses shared memory to hold both the raw data to be sorted and the final sorted results. It is set up for one data item per thread, since that seems to be what you are asking for. It is easy to extend it to several elements per thread, and I have put most of the plumbing in place for that, except for the data synthesis and the debug printouts:
```cuda
#include <cub/cub.cuh>
#include <stdio.h>

#define nTPB 32
#define ELEMS_PER_THREAD 1

// Block-sorting CUDA kernel (nTPB threads each owning ELEMS_PER_THREAD integers)
__global__ void BlockSortKernel()
{
  __shared__ int my_val[nTPB*ELEMS_PER_THREAD];
  using namespace cub;
  // Specialize BlockRadixSort collective types
  typedef BlockRadixSort<int, nTPB, ELEMS_PER_THREAD> my_block_sort;
  // Allocate shared memory for collectives
  __shared__ typename my_block_sort::TempStorage sort_temp_stg;

  // need to extend synthetic data for ELEMS_PER_THREAD > 1
  my_val[threadIdx.x*ELEMS_PER_THREAD] = (threadIdx.x + 5)%nTPB; // synth data
  __syncthreads();
  printf("thread %d data = %d\n", threadIdx.x, my_val[threadIdx.x*ELEMS_PER_THREAD]);

  // Collectively sort the keys
  my_block_sort(sort_temp_stg).Sort(*static_cast<int(*)[ELEMS_PER_THREAD]>(static_cast<void*>(my_val+(threadIdx.x*ELEMS_PER_THREAD))));
  __syncthreads();
  printf("thread %d sorted data = %d\n", threadIdx.x, my_val[threadIdx.x*ELEMS_PER_THREAD]);
}

int main(){
  BlockSortKernel<<<1,nTPB>>>();
  cudaDeviceSynchronize();
}
```
This seems to work correctly for me. I tested it with RHEL 5.5 / gcc 4.1.2, CUDA 6.0 RC and CUB v1.2.0 (which is fairly recent).
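Extending this to several elements per thread mostly means filling in the data synthesis and replacing the single-item accesses with a loop over each thread's segment. Here is a minimal sketch, assuming `ELEMS_PER_THREAD = 4`, a descending sequence as synthetic data, and a hypothetical output buffer `out` plus a hypothetical host helper `run_block_sort` (neither is part of the code above). It keeps the same shared-memory cast:

```cuda
#include <cub/cub.cuh>
#include <cassert>
#include <cstdio>

#define nTPB 32
#define ELEMS_PER_THREAD 4
#define N_ITEMS (nTPB * ELEMS_PER_THREAD)

// Each thread owns a contiguous (blocked) segment of ELEMS_PER_THREAD items
// in shared memory; the sorted result is copied to global memory `out`.
__global__ void BlockSortKernel(int *out)
{
    typedef cub::BlockRadixSort<int, nTPB, ELEMS_PER_THREAD> my_block_sort;
    __shared__ typename my_block_sort::TempStorage sort_temp_stg;
    __shared__ int my_val[N_ITEMS];

    int *my_seg = my_val + threadIdx.x * ELEMS_PER_THREAD;

    // Synthetic data: the values 0..N_ITEMS-1 in descending order
    for (int i = 0; i < ELEMS_PER_THREAD; i++)
        my_seg[i] = N_ITEMS - 1 - (threadIdx.x * ELEMS_PER_THREAD + i);

    // View this thread's segment as int[ELEMS_PER_THREAD], as Sort() requires
    my_block_sort(sort_temp_stg).Sort(
        *static_cast<int(*)[ELEMS_PER_THREAD]>(static_cast<void*>(my_seg)));

    // Each thread copies its own (now sorted, blocked) segment out
    for (int i = 0; i < ELEMS_PER_THREAD; i++)
        out[threadIdx.x * ELEMS_PER_THREAD + i] = my_seg[i];
}

// Hypothetical helper: launches one block, copies the result to h_out.
// Returns 0 on success, 1 on any CUDA error.
int run_block_sort(int *h_out)
{
    int *d_out = NULL;
    if (cudaMalloc(&d_out, N_ITEMS * sizeof(int)) != cudaSuccess) return 1;
    BlockSortKernel<<<1, nTPB>>>(d_out);
    cudaError_t err = cudaGetLastError();
    if (err == cudaSuccess)
        err = cudaMemcpy(h_out, d_out, N_ITEMS * sizeof(int),
                         cudaMemcpyDeviceToHost);
    cudaFree(d_out);
    if (err != cudaSuccess) {
        printf("CUDA error: %s\n", cudaGetErrorString(err));
        return 1;
    }
    return 0;
}

int main()
{
    int h_out[N_ITEMS];
    int rc = run_block_sort(h_out);
    if (rc != 0) return 1;
    for (int i = 0; i < N_ITEMS; i++)
        assert(h_out[i] == i);  // descending input must come back ascending
    printf("sorted %d items\n", N_ITEMS);
    return 0;
}
```

Because `Sort()` leaves the keys in a blocked arrangement, thread `t` ends up holding ranks `t*ELEMS_PER_THREAD` through `t*ELEMS_PER_THREAD + ELEMS_PER_THREAD - 1`, so copying each segment out in thread order yields the fully sorted array.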
A strange/ugly cast is necessary, as far as I can tell, because CUB's `Sort` expects a reference to an array whose length equals the `ITEMS_PER_THREAD` template parameter (i.e. `ELEMS_PER_THREAD`):
```cuda
__device__ __forceinline__ void Sort(
    Key  (&keys)[ITEMS_PER_THREAD],
    int  begin_bit = 0,
    int  end_bit   = sizeof(Key) * 8)
{
    ...
```
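If the cast bothers you, one way to avoid it (a different pattern from the one above, not something CUB requires) is to copy each thread's segment out of shared memory into a genuine local array `int keys[ELEMS_PER_THREAD]`, which binds to `Key (&keys)[ITEMS_PER_THREAD]` directly, sort that, and write the result back. With fully unrolled loops the compiler can usually keep such an array in registers, at the cost of one extra copy in each direction. This sketch assumes the same descending synthetic data and a hypothetical host helper `run_block_sort_local`:

```cuda
#include <cub/cub.cuh>
#include <cassert>
#include <cstdio>

#define nTPB 32
#define ELEMS_PER_THREAD 4
#define N_ITEMS (nTPB * ELEMS_PER_THREAD)

__global__ void BlockSortLocalKernel(int *out)
{
    typedef cub::BlockRadixSort<int, nTPB, ELEMS_PER_THREAD> my_block_sort;
    __shared__ typename my_block_sort::TempStorage sort_temp_stg;
    __shared__ int my_val[N_ITEMS];

    // Synthetic data in shared memory: 0..N_ITEMS-1 in descending order
    #pragma unroll
    for (int i = 0; i < ELEMS_PER_THREAD; i++) {
        int idx = threadIdx.x * ELEMS_PER_THREAD + i;
        my_val[idx] = N_ITEMS - 1 - idx;
    }

    // Copy this thread's blocked segment into a real int[ELEMS_PER_THREAD];
    // it matches Sort()'s array-reference parameter with no cast
    int keys[ELEMS_PER_THREAD];
    #pragma unroll
    for (int i = 0; i < ELEMS_PER_THREAD; i++)
        keys[i] = my_val[threadIdx.x * ELEMS_PER_THREAD + i];

    my_block_sort(sort_temp_stg).Sort(keys);

    // Store the sorted (blocked) keys back to shared memory and to `out`
    #pragma unroll
    for (int i = 0; i < ELEMS_PER_THREAD; i++) {
        int idx = threadIdx.x * ELEMS_PER_THREAD + i;
        my_val[idx] = keys[i];
        out[idx] = keys[i];
    }
}

// Hypothetical helper: launches one block, copies the result to h_out.
// Returns 0 on success, 1 on any CUDA error.
int run_block_sort_local(int *h_out)
{
    int *d_out = NULL;
    if (cudaMalloc(&d_out, N_ITEMS * sizeof(int)) != cudaSuccess) return 1;
    BlockSortLocalKernel<<<1, nTPB>>>(d_out);
    cudaError_t err = cudaGetLastError();
    if (err == cudaSuccess)
        err = cudaMemcpy(h_out, d_out, N_ITEMS * sizeof(int),
                         cudaMemcpyDeviceToHost);
    cudaFree(d_out);
    if (err != cudaSuccess) {
        printf("CUDA error: %s\n", cudaGetErrorString(err));
        return 1;
    }
    return 0;
}

int main()
{
    int h_out[N_ITEMS];
    int rc = run_block_sort_local(h_out);
    if (rc != 0) return 1;
    for (int i = 0; i < N_ITEMS; i++)
        assert(h_out[i] == i);  // descending input must come back ascending
    printf("sorted %d items\n", N_ITEMS);
    return 0;
}
```

Each thread only reads and writes its own segment of `my_val`, so no extra `__syncthreads()` is needed around the copies; one would be needed before other threads read the sorted shared array.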