@@ -70,10 +70,12 @@ template <typename argTy, typename outTy>
7070struct TypePairSupportDataForSumAccumulation
7171{
7272 static constexpr bool is_defined = std::disjunction<
73+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, bool >,
7374 td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::int32_t >,
7475 td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::int64_t >,
7576
7677 // input int8_t
78+ td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, std::int8_t >,
7779 td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, std::int32_t >,
7880 td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, std::int64_t >,
7981
@@ -130,6 +132,10 @@ struct TypePairSupportDataForSumAccumulation
130132 td_ns::NotDefinedEntry>::is_defined;
131133};
132134
// Binary operator type that the cumulative-sum factories in this file use as
// their ScanOpT (instantiated with the destination element type).
// Resolves to sycl::logical_or<T> when T is bool, and to sycl::plus<T> for
// every other T.
// NOTE(review): presumably chosen so that a bool accumulator combines values
// with logical OR instead of relying on addition of bool operands -- confirm.
135+ template <typename T>
136+ using CumSumScanOpT = std::
137+ conditional_t <std::is_same_v<T, bool >, sycl::logical_or<T>, sycl::plus<T>>;
138+
133139template <typename fnT, typename srcTy, typename dstTy>
134140struct CumSum1DContigFactory
135141{
@@ -138,7 +144,7 @@ struct CumSum1DContigFactory
138144 if constexpr (TypePairSupportDataForSumAccumulation<srcTy,
139145 dstTy>::is_defined)
140146 {
141- using ScanOpT = sycl::plus <dstTy>;
147+ using ScanOpT = CumSumScanOpT <dstTy>;
142148 constexpr bool include_initial = false ;
143149 if constexpr (std::is_same_v<srcTy, dstTy>) {
144150 using dpctl::tensor::kernels::accumulators::NoOpTransformer;
@@ -171,7 +177,7 @@ struct CumSum1DIncludeInitialContigFactory
171177 if constexpr (TypePairSupportDataForSumAccumulation<srcTy,
172178 dstTy>::is_defined)
173179 {
174- using ScanOpT = sycl::plus <dstTy>;
180+ using ScanOpT = CumSumScanOpT <dstTy>;
175181 constexpr bool include_initial = true ;
176182 if constexpr (std::is_same_v<srcTy, dstTy>) {
177183 using dpctl::tensor::kernels::accumulators::NoOpTransformer;
@@ -204,7 +210,7 @@ struct CumSumStridedFactory
204210 if constexpr (TypePairSupportDataForSumAccumulation<srcTy,
205211 dstTy>::is_defined)
206212 {
207- using ScanOpT = sycl::plus <dstTy>;
213+ using ScanOpT = CumSumScanOpT <dstTy>;
208214 constexpr bool include_initial = false ;
209215 if constexpr (std::is_same_v<srcTy, dstTy>) {
210216 using dpctl::tensor::kernels::accumulators::NoOpTransformer;
@@ -237,7 +243,7 @@ struct CumSumIncludeInitialStridedFactory
237243 if constexpr (TypePairSupportDataForSumAccumulation<srcTy,
238244 dstTy>::is_defined)
239245 {
240- using ScanOpT = sycl::plus <dstTy>;
246+ using ScanOpT = CumSumScanOpT <dstTy>;
241247 constexpr bool include_initial = true ;
242248 if constexpr (std::is_same_v<srcTy, dstTy>) {
243249 using dpctl::tensor::kernels::accumulators::NoOpTransformer;
0 commit comments