@@ -21,6 +21,13 @@ limitations under the License.
#define EPSILON 1e-5
#define BLOCK_SIZE 32

// Ceil-division: number of STRIDE-sized tiles needed to cover SIZE elements.
// Function-like macros must have NO space between the name and '(' —
// with a space the preprocessor defines an object-like macro and every
// call site fails to compile.
#define TILE(SIZE, STRIDE) ((((SIZE)-1) / (STRIDE)) + 1)

// Warp-shuffle portability shims:
//  - HIP/AMD has no *_sync variants (and no participation mask), so the
//    legacy mask-less intrinsics are used there.
//  - CUDA (Volta+) requires the *_sync forms with an explicit full-warp mask.
// NOTE(review): names starting with '__' are reserved identifiers; kept
// unchanged here because existing call sites use them. Also, on AMD warpSize
// is 64 while these shims are used for 32-lane reductions (offsets <= 16) —
// presumably lane indices are taken modulo 32 by the callers; verify on
// gfx targets.
#ifdef __HIP_PLATFORM_AMD__
#define __SHFL_DOWN(a, b) __shfl_down(a, b)
#define __SHFL_XOR(a, b) __shfl_xor(a, b)
#else
#define __SHFL_DOWN(a, b) __shfl_down_sync(0xffffffff, a, b)
#define __SHFL_XOR(a, b) __shfl_xor_sync(0xffffffff, a, b)
#endif
2431
2532template <int warp_count, int load_count>
2633__global__ void CovarianceReductionKernel (
@@ -82,13 +89,11 @@ __global__ void CovarianceReductionKernel(
8289
8390 for (int i = 0 ; i < MATRIX_COMPONENT_COUNT; i++) {
8491 float matrix_component = matrix[i];
85-
86- matrix_component += __shfl_down_sync (0xffffffff , matrix_component, 16 );
87- matrix_component += __shfl_down_sync (0xffffffff , matrix_component, 8 );
88- matrix_component += __shfl_down_sync (0xffffffff , matrix_component, 4 );
89- matrix_component += __shfl_down_sync (0xffffffff , matrix_component, 2 );
90- matrix_component += __shfl_down_sync (0xffffffff , matrix_component, 1 );
91-
92+ matrix_component += __SHFL_DOWN (matrix_component, 16 );
93+ matrix_component += __SHFL_DOWN (matrix_component, 8 );
94+ matrix_component += __SHFL_DOWN (matrix_component, 4 );
95+ matrix_component += __SHFL_DOWN (matrix_component, 2 );
96+ matrix_component += __SHFL_DOWN (matrix_component, 1 );
9297 if (lane_index == 0 ) {
9398 s_matrix_component[warp_index] = matrix_component;
9499 }
@@ -97,23 +102,21 @@ __global__ void CovarianceReductionKernel(
97102
98103 if (warp_index == 0 ) {
99104 matrix_component = s_matrix_component[lane_index];
100-
101105 if (warp_count >= 32 ) {
102- matrix_component += __shfl_down_sync ( 0xffffffff , matrix_component, 16 );
106+ matrix_component += __SHFL_DOWN ( matrix_component, 16 );
103107 }
104108 if (warp_count >= 16 ) {
105- matrix_component += __shfl_down_sync ( 0xffffffff , matrix_component, 8 );
109+ matrix_component += __SHFL_DOWN ( matrix_component, 8 );
106110 }
107111 if (warp_count >= 8 ) {
108- matrix_component += __shfl_down_sync ( 0xffffffff , matrix_component, 4 );
112+ matrix_component += __SHFL_DOWN ( matrix_component, 4 );
109113 }
110114 if (warp_count >= 4 ) {
111- matrix_component += __shfl_down_sync ( 0xffffffff , matrix_component, 2 );
115+ matrix_component += __SHFL_DOWN ( matrix_component, 2 );
112116 }
113117 if (warp_count >= 2 ) {
114- matrix_component += __shfl_down_sync ( 0xffffffff , matrix_component, 1 );
118+ matrix_component += __SHFL_DOWN ( matrix_component, 1 );
115119 }
116-
117120 if (lane_index == 0 ) {
118121 g_batch_matrices[matrix_offset + i] = matrix_component;
119122 }
@@ -156,13 +159,11 @@ __global__ void CovarianceFinalizationKernel(const float* g_matrices, float* g_g
156159 matrix_component += g_batch_matrices[(matrix_offset + matrix_index) * GMM_COMPONENT_COUNT + index];
157160 }
158161 }
159-
160- matrix_component += __shfl_down_sync (0xffffffff , matrix_component, 16 );
161- matrix_component += __shfl_down_sync (0xffffffff , matrix_component, 8 );
162- matrix_component += __shfl_down_sync (0xffffffff , matrix_component, 4 );
163- matrix_component += __shfl_down_sync (0xffffffff , matrix_component, 2 );
164- matrix_component += __shfl_down_sync (0xffffffff , matrix_component, 1 );
165-
162+ matrix_component += __SHFL_DOWN (matrix_component, 16 );
163+ matrix_component += __SHFL_DOWN (matrix_component, 8 );
164+ matrix_component += __SHFL_DOWN (matrix_component, 4 );
165+ matrix_component += __SHFL_DOWN (matrix_component, 2 );
166+ matrix_component += __SHFL_DOWN (matrix_component, 1 );
166167 if (lane_index == 0 ) {
167168 s_matrix_component[warp_index] = matrix_component;
168169 }
@@ -171,23 +172,21 @@ __global__ void CovarianceFinalizationKernel(const float* g_matrices, float* g_g
171172
172173 if (warp_index == 0 ) {
173174 matrix_component = s_matrix_component[lane_index];
174-
175175 if (warp_count >= 32 ) {
176- matrix_component += __shfl_down_sync ( 0xffffffff , matrix_component, 16 );
176+ matrix_component += __SHFL_DOWN ( matrix_component, 16 );
177177 }
178178 if (warp_count >= 16 ) {
179- matrix_component += __shfl_down_sync ( 0xffffffff , matrix_component, 8 );
179+ matrix_component += __SHFL_DOWN ( matrix_component, 8 );
180180 }
181181 if (warp_count >= 8 ) {
182- matrix_component += __shfl_down_sync ( 0xffffffff , matrix_component, 4 );
182+ matrix_component += __SHFL_DOWN ( matrix_component, 4 );
183183 }
184184 if (warp_count >= 4 ) {
185- matrix_component += __shfl_down_sync ( 0xffffffff , matrix_component, 2 );
185+ matrix_component += __SHFL_DOWN ( matrix_component, 2 );
186186 }
187187 if (warp_count >= 2 ) {
188- matrix_component += __shfl_down_sync ( 0xffffffff , matrix_component, 1 );
188+ matrix_component += __SHFL_DOWN ( matrix_component, 1 );
189189 }
190-
191190 if (lane_index == 0 ) {
192191 float constant = i == 0 ? 0 .0f : s_gmm[i] * s_gmm[j];
193192
@@ -261,13 +260,11 @@ __global__ void GMMFindSplit(GMMSplit_t* gmmSplit, int gmmK, float* gmm) {
261260 }
262261
263262 float max_value = eigenvalue;
264-
265- max_value = max (max_value, __shfl_xor_sync (0xffffffff , max_value, 16 ));
266- max_value = max (max_value, __shfl_xor_sync (0xffffffff , max_value, 8 ));
267- max_value = max (max_value, __shfl_xor_sync (0xffffffff , max_value, 4 ));
268- max_value = max (max_value, __shfl_xor_sync (0xffffffff , max_value, 2 ));
269- max_value = max (max_value, __shfl_xor_sync (0xffffffff , max_value, 1 ));
270-
263+ max_value = max (max_value, __SHFL_XOR (max_value, 16 ));
264+ max_value = max (max_value, __SHFL_XOR (max_value, 8 ));
265+ max_value = max (max_value, __SHFL_XOR (max_value, 4 ));
266+ max_value = max (max_value, __SHFL_XOR (max_value, 2 ));
267+ max_value = max (max_value, __SHFL_XOR (max_value, 1 ));
271268 if (max_value == eigenvalue) {
272269 GMMSplit_t split;
273270
@@ -347,12 +344,11 @@ __global__ void GMMcommonTerm(float* g_gmm) {
347344 float gmm_n = threadIdx .x < MIXTURE_SIZE ? g_batch_gmm[gmm_index * GMM_COMPONENT_COUNT] : 0 .0f ;
348345
349346 float sum = gmm_n;
350-
351- sum += __shfl_xor_sync (0xffffffff , sum, 1 );
352- sum += __shfl_xor_sync (0xffffffff , sum, 2 );
353- sum += __shfl_xor_sync (0xffffffff , sum, 4 );
354- sum += __shfl_xor_sync (0xffffffff , sum, 8 );
355- sum += __shfl_xor_sync (0xffffffff , sum, 16 );
347+ sum += __SHFL_XOR (sum, 1 );
348+ sum += __SHFL_XOR (sum, 2 );
349+ sum += __SHFL_XOR (sum, 4 );
350+ sum += __SHFL_XOR (sum, 8 );
351+ sum += __SHFL_XOR (sum, 16 );
356352
357353 if (threadIdx .x < MIXTURE_SIZE) {
358354 float det = g_batch_gmm[gmm_index * GMM_COMPONENT_COUNT + MATRIX_COMPONENT_COUNT] + EPSILON;
@@ -446,13 +442,14 @@ void GMMInitialize(
446442 for (unsigned int k = MIXTURE_COUNT; k < gmm_N; k += MIXTURE_COUNT) {
447443 for (unsigned int i = 0 ; i < k; ++i) {
448444 CovarianceReductionKernel<WARPS, LOAD>
449- <<<{ block_count, 1 , batch_count} , BLOCK>>> (i, image, alpha, block_gmm_scratch, element_count);
445+ <<<dim3 ( block_count, 1 , batch_count) , BLOCK>>> (i, image, alpha, block_gmm_scratch, element_count);
450446 }
451447
452- CovarianceFinalizationKernel<WARPS, false ><<<{ k, 1 , batch_count} , BLOCK>>> (block_gmm_scratch, gmm, block_count);
448+ CovarianceFinalizationKernel<WARPS, false ><<<dim3 ( k, 1 , batch_count) , BLOCK>>> (block_gmm_scratch, gmm, block_count);
453449
454- GMMFindSplit<<<{1 , 1 , batch_count}, dim3 (BLOCK_SIZE, MIXTURE_COUNT)>>> (gmm_split_scratch, k / MIXTURE_COUNT, gmm);
455- GMMDoSplit<<<{TILE (element_count, BLOCK_SIZE * DO_SPLIT_DEGENERACY), 1 , batch_count}, BLOCK_SIZE>>> (
450+ GMMFindSplit<<<dim3 (1 , 1 , batch_count), dim3 (BLOCK_SIZE, MIXTURE_COUNT)>>> (
451+ gmm_split_scratch, k / MIXTURE_COUNT, gmm);
452+ GMMDoSplit<<<dim3 (TILE(element_count, BLOCK_SIZE * DO_SPLIT_DEGENERACY), 1 , batch_count), BLOCK_SIZE>>> (
456453 gmm_split_scratch, (k / MIXTURE_COUNT) << 4 , image, alpha, element_count);
457454 }
458455}
@@ -472,12 +469,13 @@ void GMMUpdate(
472469
473470 for (unsigned int i = 0 ; i < gmm_N; ++i) {
474471 CovarianceReductionKernel<WARPS, LOAD>
475- <<<{ block_count, 1 , batch_count} , BLOCK>>> (i, image, alpha, block_gmm_scratch, element_count);
472+ <<<dim3 ( block_count, 1 , batch_count) , BLOCK>>> (i, image, alpha, block_gmm_scratch, element_count);
476473 }
477474
478- CovarianceFinalizationKernel<WARPS, true ><<<{gmm_N, 1 , batch_count}, BLOCK>>> (block_gmm_scratch, gmm, block_count);
475+ CovarianceFinalizationKernel<WARPS, true >
476+ <<<dim3 (gmm_N, 1 , batch_count), BLOCK>>> (block_gmm_scratch, gmm, block_count);
479477
480- GMMcommonTerm<<<{ 1 , 1 , batch_count} , dim3 (BLOCK_SIZE, MIXTURE_COUNT)>>> (gmm);
478+ GMMcommonTerm<<<dim3 ( 1 , 1 , batch_count) , dim3 (BLOCK_SIZE, MIXTURE_COUNT)>>> (gmm);
481479}
482480
483481void GMMDataTerm (
0 commit comments