Skip to content

Commit a7c6440

Browse files
Move dpt.isin() and reuse it in dpnp
1 parent ce570a5 commit a7c6440

3 files changed

Lines changed: 172 additions & 2 deletions

File tree

dpctl_ext/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
from ._clip import clip
8585
from ._searchsorted import searchsorted
8686
from ._set_functions import (
87+
isin,
8788
unique_all,
8889
unique_counts,
8990
unique_inverse,
@@ -119,6 +120,7 @@
119120
"full_like",
120121
"iinfo",
121122
"isdtype",
123+
"isin",
122124
"linspace",
123125
"meshgrid",
124126
"moveaxis",

dpctl_ext/tensor/_set_functions.py

Lines changed: 166 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
# THE POSSIBILITY OF SUCH DAMAGE.
2727
# *****************************************************************************
2828

29-
from typing import NamedTuple
29+
from typing import NamedTuple, Optional, Union
3030

3131
import dpctl.tensor as dpt
3232
import dpctl.utils as du
@@ -36,6 +36,13 @@
3636
# when dpnp fully migrates dpctl/tensor
3737
import dpctl_ext.tensor as dpt_ext
3838

39+
from ._copy_utils import _empty_like_orderK
40+
from ._scalar_utils import (
41+
_get_dtype,
42+
_get_queue_usm_type,
43+
_get_shape,
44+
_validate_dtype,
45+
)
3946
from ._tensor_impl import (
4047
_copy_usm_ndarray_into_usm_ndarray,
4148
_extract,
@@ -47,9 +54,25 @@
4754
)
4855
from ._tensor_sorting_impl import (
4956
_argsort_ascending,
57+
_isin,
5058
_searchsorted_left,
5159
_sort_ascending,
5260
)
61+
from ._type_utils import (
62+
_resolve_weak_types_all_py_ints,
63+
_to_device_supported_dtype,
64+
)
65+
66+
__all__ = [
67+
"isin",
68+
"unique_values",
69+
"unique_counts",
70+
"unique_inverse",
71+
"unique_all",
72+
"UniqueAllResult",
73+
"UniqueCountsResult",
74+
"UniqueInverseResult",
75+
]
5376

5477

5578
class UniqueAllResult(NamedTuple):
@@ -636,3 +659,145 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
636659
inv,
637660
_counts,
638661
)
662+
663+
664+
def isin(
665+
x: Union[dpt.usm_ndarray, int, float, complex, bool],
666+
test_elements: Union[dpt.usm_ndarray, int, float, complex, bool],
667+
/,
668+
*,
669+
invert: Optional[bool] = False,
670+
) -> dpt.usm_ndarray:
671+
"""isin(x, test_elements, /, *, invert=False)
672+
673+
Tests `x in test_elements` for each element of `x`. Returns a boolean array
674+
with the same shape as `x` that is `True` where the element is in
675+
`test_elements`, `False` otherwise.
676+
677+
Args:
678+
x (Union[usm_ndarray, bool, int, float, complex]):
679+
input element or elements.
680+
test_elements (Union[usm_ndarray, bool, int, float, complex]):
681+
elements against which to test each value of `x`.
682+
invert (Optional[bool]):
683+
if `True`, the output results are inverted, i.e., are equivalent to
684+
testing `x not in test_elements` for each element of `x`.
685+
Default: `False`.
686+
687+
Returns:
688+
usm_ndarray:
689+
an array of the inclusion test results. The returned array has a
690+
boolean data type and the same shape as `x`.
691+
"""
692+
q1, x_usm_type = _get_queue_usm_type(x)
693+
q2, test_usm_type = _get_queue_usm_type(test_elements)
694+
if q1 is None and q2 is None:
695+
raise du.ExecutionPlacementError(
696+
"Execution placement can not be unambiguously inferred "
697+
"from input arguments. "
698+
"One of the arguments must represent USM allocation and "
699+
"expose `__sycl_usm_array_interface__` property"
700+
)
701+
if q1 is None:
702+
exec_q = q2
703+
res_usm_type = test_usm_type
704+
elif q2 is None:
705+
exec_q = q1
706+
res_usm_type = x_usm_type
707+
else:
708+
exec_q = du.get_execution_queue((q1, q2))
709+
if exec_q is None:
710+
raise du.ExecutionPlacementError(
711+
"Execution placement can not be unambiguously inferred "
712+
"from input arguments."
713+
)
714+
res_usm_type = du.get_coerced_usm_type(
715+
(
716+
x_usm_type,
717+
test_usm_type,
718+
)
719+
)
720+
du.validate_usm_type(res_usm_type, allow_none=False)
721+
sycl_dev = exec_q.sycl_device
722+
723+
if not isinstance(invert, bool):
724+
raise TypeError(
725+
"`invert` keyword argument must be of boolean type, "
726+
f"got {type(invert)}"
727+
)
728+
729+
x_dt = _get_dtype(x, sycl_dev)
730+
test_dt = _get_dtype(test_elements, sycl_dev)
731+
if not all(_validate_dtype(dt) for dt in (x_dt, test_dt)):
732+
raise ValueError("Operands have unsupported data types")
733+
734+
x_sh = _get_shape(x)
735+
if isinstance(test_elements, dpt.usm_ndarray) and test_elements.size == 0:
736+
if invert:
737+
return dpt_ext.ones(
738+
x_sh, dtype=dpt.bool, usm_type=res_usm_type, sycl_queue=exec_q
739+
)
740+
else:
741+
return dpt_ext.zeros(
742+
x_sh, dtype=dpt.bool, usm_type=res_usm_type, sycl_queue=exec_q
743+
)
744+
745+
dt1, dt2 = _resolve_weak_types_all_py_ints(x_dt, test_dt, sycl_dev)
746+
dt = _to_device_supported_dtype(dpt_ext.result_type(dt1, dt2), sycl_dev)
747+
748+
if not isinstance(x, dpt.usm_ndarray):
749+
x_arr = dpt_ext.asarray(
750+
x, dtype=dt1, usm_type=res_usm_type, sycl_queue=exec_q
751+
)
752+
else:
753+
x_arr = x
754+
755+
if not isinstance(test_elements, dpt.usm_ndarray):
756+
test_arr = dpt_ext.asarray(
757+
test_elements, dtype=dt2, usm_type=res_usm_type, sycl_queue=exec_q
758+
)
759+
else:
760+
test_arr = test_elements
761+
762+
_manager = du.SequentialOrderManager[exec_q]
763+
dep_evs = _manager.submitted_events
764+
765+
if x_dt != dt:
766+
x_buf = _empty_like_orderK(x_arr, dt, res_usm_type, exec_q)
767+
ht_ev, ev = _copy_usm_ndarray_into_usm_ndarray(
768+
src=x_arr, dst=x_buf, sycl_queue=exec_q, depends=dep_evs
769+
)
770+
_manager.add_event_pair(ht_ev, ev)
771+
else:
772+
x_buf = x_arr
773+
774+
if test_dt != dt:
775+
# copy into C-contiguous memory, because the array will be flattened
776+
test_buf = dpt_ext.empty_like(
777+
test_arr, dtype=dt, order="C", usm_type=res_usm_type
778+
)
779+
ht_ev, ev = _copy_usm_ndarray_into_usm_ndarray(
780+
src=test_arr, dst=test_buf, sycl_queue=exec_q, depends=dep_evs
781+
)
782+
_manager.add_event_pair(ht_ev, ev)
783+
else:
784+
test_buf = test_arr
785+
786+
test_buf = dpt_ext.reshape(test_buf, -1)
787+
test_buf = dpt_ext.sort(test_buf)
788+
789+
dst = dpt_ext.empty_like(
790+
x_buf, dtype=dpt.bool, usm_type=res_usm_type, order="C"
791+
)
792+
793+
dep_evs = _manager.submitted_events
794+
ht_ev, s_ev = _isin(
795+
needles=x_buf,
796+
hay=test_buf,
797+
dst=dst,
798+
sycl_queue=exec_q,
799+
invert=invert,
800+
depends=dep_evs,
801+
)
802+
_manager.add_event_pair(ht_ev, s_ev)
803+
return dst

dpnp/dpnp_iface_logic.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@
4949
import dpctl.utils as dpu
5050
import numpy
5151

52+
# TODO: revert to `import dpctl.tensor...`
53+
# when dpnp fully migrates dpctl/tensor
54+
import dpctl_ext.tensor as dpt_ext
5255
import dpnp
5356
import dpnp.backend.extensions.ufunc._ufunc_impl as ufi
5457
from dpnp.dpnp_algo.dpnp_elementwise_common import DPNPBinaryFunc, DPNPUnaryFunc
@@ -1273,7 +1276,7 @@ def isin(
12731276
usm_element = dpnp.get_usm_ndarray(element)
12741277
usm_test = dpnp.get_usm_ndarray(test_elements)
12751278
return dpnp_array._create_from_usm_ndarray(
1276-
dpt.isin(
1279+
dpt_ext.isin(
12771280
usm_element,
12781281
usm_test,
12791282
invert=invert,

0 commit comments

Comments
 (0)