@@ -2284,17 +2284,17 @@ def dpnp_multiply(x1, x2, out=None, order="K"):
22842284def dpnp_negative (x , out = None , order = "K" ):
22852285 """Invokes negative() from dpctl.tensor implementation for negative() function."""
22862286
2287- # TODO: discuss with dpctl if the check is needed to be moved there
2287+ # dpctl.tensor only works with usm_ndarray
2288+ x1_usm = dpnp .get_usm_ndarray (x )
2289+ out_usm = None if out is None else dpnp .get_usm_ndarray (out )
2290+
2291+ # TODO: discuss with dpctl if the check is needed to be moved out of there
22882292 if not dpnp .isscalar (x ) and x .dtype == dpnp .bool :
22892293 raise TypeError (
22902294 "DPNP boolean negative, the `-` operator, is not supported, "
22912295 "use the `~` operator or the logical_not function instead."
22922296 )
22932297
2294- # dpctl.tensor only works with usm_ndarray
2295- x1_usm = dpnp .get_usm_ndarray (x )
2296- out_usm = None if out is None else dpnp .get_usm_ndarray (out )
2297-
22982298 res_usm = negative_func (x1_usm , out = out_usm , order = order )
22992299 return _get_result (res_usm , out = out )
23002300
@@ -2966,24 +2966,31 @@ def dpnp_subtract(x1, x2, out=None, order="K"):
29662966 Invokes sub() function from pybind11 extension of OneMKL VM if possible.
29672967
29682968 Otherwise fully relies on dpctl.tensor implementation for subtract() function.
2969- """
29702969
2971- # TODO: discuss with dpctl if the check is needed to be moved there
2972- if (
2973- not dpnp .isscalar (x1 )
2974- and not dpnp .isscalar (x2 )
2975- and x1 .dtype == x2 .dtype == dpnp .bool
2976- ):
2977- raise TypeError (
2978- "DPNP boolean subtract, the `-` operator, is not supported, "
2979- "use the bitwise_xor, the `^` operator, or the logical_xor function instead."
2980- )
2970+ """
29812971
29822972 # dpctl.tensor only works with usm_ndarray or scalar
29832973 x1_usm_or_scalar = dpnp .get_usm_ndarray_or_scalar (x1 )
29842974 x2_usm_or_scalar = dpnp .get_usm_ndarray_or_scalar (x2 )
29852975 out_usm = None if out is None else dpnp .get_usm_ndarray (out )
29862976
2977+ # TODO: discuss with dpctl if the check is needed to be moved out of there
2978+ boolean_subtract = False
2979+ if dpnp .isscalar (x1 ):
2980+ if isinstance (x1 , bool ) and x2 .dtype == dpnp .bool :
2981+ boolean_subtract = True
2982+ elif dpnp .isscalar (x2 ):
2983+ if isinstance (x2 , bool ) and x1 .dtype == dpnp .bool :
2984+ boolean_subtract = True
2985+ elif x1 .dtype == x2 .dtype == dpnp .bool :
2986+ boolean_subtract = True
2987+
2988+ if boolean_subtract :
2989+ raise TypeError (
2990+ "DPNP boolean subtract, the `-` operator, is not supported, "
2991+ "use the bitwise_xor, the `^` operator, or the logical_xor function instead."
2992+ )
2993+
29872994 res_usm = subtract_func (
29882995 x1_usm_or_scalar , x2_usm_or_scalar , out = out_usm , order = order
29892996 )
0 commit comments