2323// ===---------------------------------------------------------------------===//
2424
2525#pragma once
26+ #include < cstddef>
2627#include < cstdint>
2728#include < limits>
2829#include < sycl/sycl.hpp>
@@ -42,6 +43,7 @@ namespace kernels
4243namespace indexing
4344{
4445
46+ using dpctl::tensor::ssize_t ;
4547using namespace dpctl ::tensor::offset_utils;
4648
4749template <typename OrthogIndexerT,
@@ -55,7 +57,7 @@ struct MaskedExtractStridedFunctor
5557 MaskedExtractStridedFunctor (const dataT *src_data_p,
5658 const indT *cumsum_data_p,
5759 dataT *dst_data_p,
58- size_t masked_iter_size,
60+ std:: size_t masked_iter_size,
5961 const OrthogIndexerT &orthog_src_dst_indexer_,
6062 const MaskedSrcIndexerT &masked_src_indexer_,
6163 const MaskedDstIndexerT &masked_dst_indexer_,
@@ -81,7 +83,7 @@ struct MaskedExtractStridedFunctor
8183
8284 const std::size_t max_offset = masked_nelems + 1 ;
8385 for (std::uint32_t i = l_i; i < lacc.size (); i += lws) {
84- const size_t offset = masked_block_start + i;
86+ const std:: size_t offset = masked_block_start + i;
8587 lacc[i] = (offset == 0 ) ? indT (0 )
8688 : (offset < max_offset) ? cumsum[offset - 1 ]
8789 : cumsum[masked_nelems - 1 ] + 1 ;
@@ -99,9 +101,10 @@ struct MaskedExtractStridedFunctor
99101 if (mask_set && (masked_i < masked_nelems)) {
100102 const auto &orthog_offsets = orthog_src_dst_indexer (orthog_i);
101103
102- const size_t total_src_offset = masked_src_indexer (masked_i) +
103- orthog_offsets.get_first_offset ();
104- const size_t total_dst_offset =
104+ const std::size_t total_src_offset =
105+ masked_src_indexer (masked_i) +
106+ orthog_offsets.get_first_offset ();
107+ const std::size_t total_dst_offset =
105108 masked_dst_indexer (current_running_count - 1 ) +
106109 orthog_offsets.get_second_offset ();
107110
@@ -113,7 +116,7 @@ struct MaskedExtractStridedFunctor
113116 const dataT *src = nullptr ;
114117 const indT *cumsum = nullptr ;
115118 dataT *dst = nullptr ;
116- const size_t masked_nelems = 0 ;
119+ const std:: size_t masked_nelems = 0 ;
117120 // has nd, shape, src_strides, dst_strides for
118121 // dimensions that ARE NOT masked
119122 const OrthogIndexerT orthog_src_dst_indexer;
@@ -136,7 +139,7 @@ struct MaskedPlaceStridedFunctor
136139 MaskedPlaceStridedFunctor (dataT *dst_data_p,
137140 const indT *cumsum_data_p,
138141 const dataT *rhs_data_p,
139- size_t masked_iter_size,
142+ std:: size_t masked_iter_size,
140143 const OrthogIndexerT &orthog_dst_rhs_indexer_,
141144 const MaskedDstIndexerT &masked_dst_indexer_,
142145 const MaskedRhsIndexerT &masked_rhs_indexer_,
@@ -157,12 +160,12 @@ struct MaskedPlaceStridedFunctor
157160 const std::uint32_t l_i = ndit.get_local_id (1 );
158161 const std::uint32_t lws = ndit.get_local_range (1 );
159162
160- const size_t masked_i = ndit.get_global_id (1 );
161- const size_t masked_block_start = masked_i - l_i;
163+ const std:: size_t masked_i = ndit.get_global_id (1 );
164+ const std:: size_t masked_block_start = masked_i - l_i;
162165
163166 const std::size_t max_offset = masked_nelems + 1 ;
164167 for (std::uint32_t i = l_i; i < lacc.size (); i += lws) {
165- const size_t offset = masked_block_start + i;
168+ const std:: size_t offset = masked_block_start + i;
166169 lacc[i] = (offset == 0 ) ? indT (0 )
167170 : (offset < max_offset) ? cumsum[offset - 1 ]
168171 : cumsum[masked_nelems - 1 ] + 1 ;
@@ -180,9 +183,10 @@ struct MaskedPlaceStridedFunctor
180183 if (mask_set && (masked_i < masked_nelems)) {
181184 const auto &orthog_offsets = orthog_dst_rhs_indexer (orthog_i);
182185
183- const size_t total_dst_offset = masked_dst_indexer (masked_i) +
184- orthog_offsets.get_first_offset ();
185- const size_t total_rhs_offset =
186+ const std::size_t total_dst_offset =
187+ masked_dst_indexer (masked_i) +
188+ orthog_offsets.get_first_offset ();
189+ const std::size_t total_rhs_offset =
186190 masked_rhs_indexer (current_running_count - 1 ) +
187191 orthog_offsets.get_second_offset ();
188192
@@ -194,7 +198,7 @@ struct MaskedPlaceStridedFunctor
194198 dataT *dst = nullptr ;
195199 const indT *cumsum = nullptr ;
196200 const dataT *rhs = nullptr ;
197- const size_t masked_nelems = 0 ;
201+ const std:: size_t masked_nelems = 0 ;
198202 // has nd, shape, dst_strides, rhs_strides for
199203 // dimensions that ARE NOT masked
200204 const OrthogIndexerT orthog_dst_rhs_indexer;
@@ -450,8 +454,8 @@ sycl::event masked_extract_some_slices_strided_impl(
450454
451455 const std::size_t lws = get_lws (masked_extent);
452456
453- const size_t n_groups = ((masked_extent + lws - 1 ) / lws);
454- const size_t orthog_extent = static_cast <size_t >(orthog_nelems);
457+ const std:: size_t n_groups = ((masked_extent + lws - 1 ) / lws);
458+ const std:: size_t orthog_extent = static_cast <std:: size_t >(orthog_nelems);
455459
456460 sycl::range<2 > gRange {orthog_extent, n_groups * lws};
457461 sycl::range<2 > lRange{1 , lws};
@@ -809,7 +813,7 @@ sycl::event non_zero_indexes_impl(sycl::queue &exec_q,
809813 const std::size_t masked_block_start = group_i * lws;
810814
811815 for (std::uint32_t i = l_i; i < lacc.size (); i += lws) {
812- const size_t offset = masked_block_start + i;
816+ const std:: size_t offset = masked_block_start + i;
813817 lacc[i] = (offset == 0 ) ? indT1 (0 )
814818 : (offset - 1 < masked_extent)
815819 ? cumsum_data[offset - 1 ]
0 commit comments