@@ -120,6 +120,7 @@ struct TypePairSupportDataForProductReductionTemps
120120{
121121
122122 static constexpr bool is_defined = std::disjunction<
123+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, bool >,
123124 td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::int8_t >,
124125 td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::uint8_t >,
125126 td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::int16_t >,
@@ -224,7 +225,7 @@ struct TypePairSupportDataForProductReductionTemps
224225 outTy,
225226 std::complex <double >>,
226227
227- // fall-throug
228+ // fall-through
228229 td_ns::NotDefinedEntry>::is_defined;
229230};
230231
@@ -255,7 +256,9 @@ struct ProductOverAxisTempsStridedFactory
255256 if constexpr (TypePairSupportDataForProductReductionTemps<
256257 srcTy, dstTy>::is_defined)
257258 {
258- using ReductionOpT = sycl::multiplies<dstTy>;
259+ using ReductionOpT = std::conditional_t <std::is_same_v<dstTy, bool >,
260+ sycl::logical_and<dstTy>,
261+ sycl::multiplies<dstTy>>;
259262 return dpctl::tensor::kernels::
260263 reduction_over_group_temps_strided_impl<srcTy, dstTy,
261264 ReductionOpT>;
@@ -312,7 +315,9 @@ struct ProductOverAxis1TempsContigFactory
312315 if constexpr (TypePairSupportDataForProductReductionTemps<
313316 srcTy, dstTy>::is_defined)
314317 {
315- using ReductionOpT = sycl::multiplies<dstTy>;
318+ using ReductionOpT = std::conditional_t <std::is_same_v<dstTy, bool >,
319+ sycl::logical_and<dstTy>,
320+ sycl::multiplies<dstTy>>;
316321 return dpctl::tensor::kernels::
317322 reduction_axis1_over_group_temps_contig_impl<srcTy, dstTy,
318323 ReductionOpT>;
@@ -331,7 +336,9 @@ struct ProductOverAxis0TempsContigFactory
331336 if constexpr (TypePairSupportDataForProductReductionTemps<
332337 srcTy, dstTy>::is_defined)
333338 {
334- using ReductionOpT = sycl::multiplies<dstTy>;
339+ using ReductionOpT = std::conditional_t <std::is_same_v<dstTy, bool >,
340+ sycl::logical_and<dstTy>,
341+ sycl::multiplies<dstTy>>;
335342 return dpctl::tensor::kernels::
336343 reduction_axis0_over_group_temps_contig_impl<srcTy, dstTy,
337344 ReductionOpT>;
0 commit comments