|
26 | 26 | # THE POSSIBILITY OF SUCH DAMAGE. |
27 | 27 | # ***************************************************************************** |
28 | 28 |
|
29 | | -from typing import NamedTuple |
| 29 | +from typing import NamedTuple, Optional, Union |
30 | 30 |
|
31 | 31 | import dpctl.tensor as dpt |
32 | 32 | import dpctl.utils as du |
|
36 | 36 | # when dpnp fully migrates dpctl/tensor |
37 | 37 | import dpctl_ext.tensor as dpt_ext |
38 | 38 |
|
| 39 | +from ._copy_utils import _empty_like_orderK |
| 40 | +from ._scalar_utils import ( |
| 41 | + _get_dtype, |
| 42 | + _get_queue_usm_type, |
| 43 | + _get_shape, |
| 44 | + _validate_dtype, |
| 45 | +) |
39 | 46 | from ._tensor_impl import ( |
40 | 47 | _copy_usm_ndarray_into_usm_ndarray, |
41 | 48 | _extract, |
|
47 | 54 | ) |
48 | 55 | from ._tensor_sorting_impl import ( |
49 | 56 | _argsort_ascending, |
| 57 | + _isin, |
50 | 58 | _searchsorted_left, |
51 | 59 | _sort_ascending, |
52 | 60 | ) |
| 61 | +from ._type_utils import ( |
| 62 | + _resolve_weak_types_all_py_ints, |
| 63 | + _to_device_supported_dtype, |
| 64 | +) |
| 65 | + |
| 66 | +__all__ = [ |
| 67 | + "isin", |
| 68 | + "unique_values", |
| 69 | + "unique_counts", |
| 70 | + "unique_inverse", |
| 71 | + "unique_all", |
| 72 | + "UniqueAllResult", |
| 73 | + "UniqueCountsResult", |
| 74 | + "UniqueInverseResult", |
| 75 | +] |
53 | 76 |
|
54 | 77 |
|
55 | 78 | class UniqueAllResult(NamedTuple): |
@@ -636,3 +659,145 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult: |
636 | 659 | inv, |
637 | 660 | _counts, |
638 | 661 | ) |
| 662 | + |
| 663 | + |
| 664 | +def isin( |
| 665 | + x: Union[dpt.usm_ndarray, int, float, complex, bool], |
| 666 | + test_elements: Union[dpt.usm_ndarray, int, float, complex, bool], |
| 667 | + /, |
| 668 | + *, |
| 669 | + invert: Optional[bool] = False, |
| 670 | +) -> dpt.usm_ndarray: |
| 671 | + """isin(x, test_elements, /, *, invert=False) |
| 672 | +
|
| 673 | + Tests `x in test_elements` for each element of `x`. Returns a boolean array |
| 674 | + with the same shape as `x` that is `True` where the element is in |
| 675 | + `test_elements`, `False` otherwise. |
| 676 | +
|
| 677 | + Args: |
| 678 | + x (Union[usm_ndarray, bool, int, float, complex]): |
| 679 | + input element or elements. |
| 680 | + test_elements (Union[usm_ndarray, bool, int, float, complex]): |
| 681 | + elements against which to test each value of `x`. |
| 682 | + invert (Optional[bool]): |
| 683 | + if `True`, the output results are inverted, i.e., are equivalent to |
| 684 | + testing `x not in test_elements` for each element of `x`. |
| 685 | + Default: `False`. |
| 686 | +
|
| 687 | + Returns: |
| 688 | + usm_ndarray: |
| 689 | + an array of the inclusion test results. The returned array has a |
| 690 | + boolean data type and the same shape as `x`. |
| 691 | + """ |
| 692 | + q1, x_usm_type = _get_queue_usm_type(x) |
| 693 | + q2, test_usm_type = _get_queue_usm_type(test_elements) |
| 694 | + if q1 is None and q2 is None: |
| 695 | + raise du.ExecutionPlacementError( |
| 696 | + "Execution placement can not be unambiguously inferred " |
| 697 | + "from input arguments. " |
| 698 | + "One of the arguments must represent USM allocation and " |
| 699 | + "expose `__sycl_usm_array_interface__` property" |
| 700 | + ) |
| 701 | + if q1 is None: |
| 702 | + exec_q = q2 |
| 703 | + res_usm_type = test_usm_type |
| 704 | + elif q2 is None: |
| 705 | + exec_q = q1 |
| 706 | + res_usm_type = x_usm_type |
| 707 | + else: |
| 708 | + exec_q = du.get_execution_queue((q1, q2)) |
| 709 | + if exec_q is None: |
| 710 | + raise du.ExecutionPlacementError( |
| 711 | + "Execution placement can not be unambiguously inferred " |
| 712 | + "from input arguments." |
| 713 | + ) |
| 714 | + res_usm_type = du.get_coerced_usm_type( |
| 715 | + ( |
| 716 | + x_usm_type, |
| 717 | + test_usm_type, |
| 718 | + ) |
| 719 | + ) |
| 720 | + du.validate_usm_type(res_usm_type, allow_none=False) |
| 721 | + sycl_dev = exec_q.sycl_device |
| 722 | + |
| 723 | + if not isinstance(invert, bool): |
| 724 | + raise TypeError( |
| 725 | + "`invert` keyword argument must be of boolean type, " |
| 726 | + f"got {type(invert)}" |
| 727 | + ) |
| 728 | + |
| 729 | + x_dt = _get_dtype(x, sycl_dev) |
| 730 | + test_dt = _get_dtype(test_elements, sycl_dev) |
| 731 | + if not all(_validate_dtype(dt) for dt in (x_dt, test_dt)): |
| 732 | + raise ValueError("Operands have unsupported data types") |
| 733 | + |
| 734 | + x_sh = _get_shape(x) |
| 735 | + if isinstance(test_elements, dpt.usm_ndarray) and test_elements.size == 0: |
| 736 | + if invert: |
| 737 | + return dpt_ext.ones( |
| 738 | + x_sh, dtype=dpt.bool, usm_type=res_usm_type, sycl_queue=exec_q |
| 739 | + ) |
| 740 | + else: |
| 741 | + return dpt_ext.zeros( |
| 742 | + x_sh, dtype=dpt.bool, usm_type=res_usm_type, sycl_queue=exec_q |
| 743 | + ) |
| 744 | + |
| 745 | + dt1, dt2 = _resolve_weak_types_all_py_ints(x_dt, test_dt, sycl_dev) |
| 746 | + dt = _to_device_supported_dtype(dpt_ext.result_type(dt1, dt2), sycl_dev) |
| 747 | + |
| 748 | + if not isinstance(x, dpt.usm_ndarray): |
| 749 | + x_arr = dpt_ext.asarray( |
| 750 | + x, dtype=dt1, usm_type=res_usm_type, sycl_queue=exec_q |
| 751 | + ) |
| 752 | + else: |
| 753 | + x_arr = x |
| 754 | + |
| 755 | + if not isinstance(test_elements, dpt.usm_ndarray): |
| 756 | + test_arr = dpt_ext.asarray( |
| 757 | + test_elements, dtype=dt2, usm_type=res_usm_type, sycl_queue=exec_q |
| 758 | + ) |
| 759 | + else: |
| 760 | + test_arr = test_elements |
| 761 | + |
| 762 | + _manager = du.SequentialOrderManager[exec_q] |
| 763 | + dep_evs = _manager.submitted_events |
| 764 | + |
| 765 | + if x_dt != dt: |
| 766 | + x_buf = _empty_like_orderK(x_arr, dt, res_usm_type, exec_q) |
| 767 | + ht_ev, ev = _copy_usm_ndarray_into_usm_ndarray( |
| 768 | + src=x_arr, dst=x_buf, sycl_queue=exec_q, depends=dep_evs |
| 769 | + ) |
| 770 | + _manager.add_event_pair(ht_ev, ev) |
| 771 | + else: |
| 772 | + x_buf = x_arr |
| 773 | + |
| 774 | + if test_dt != dt: |
| 775 | + # copy into C-contiguous memory, because the array will be flattened |
| 776 | + test_buf = dpt_ext.empty_like( |
| 777 | + test_arr, dtype=dt, order="C", usm_type=res_usm_type |
| 778 | + ) |
| 779 | + ht_ev, ev = _copy_usm_ndarray_into_usm_ndarray( |
| 780 | + src=test_arr, dst=test_buf, sycl_queue=exec_q, depends=dep_evs |
| 781 | + ) |
| 782 | + _manager.add_event_pair(ht_ev, ev) |
| 783 | + else: |
| 784 | + test_buf = test_arr |
| 785 | + |
| 786 | + test_buf = dpt_ext.reshape(test_buf, -1) |
| 787 | + test_buf = dpt_ext.sort(test_buf) |
| 788 | + |
| 789 | + dst = dpt_ext.empty_like( |
| 790 | + x_buf, dtype=dpt.bool, usm_type=res_usm_type, order="C" |
| 791 | + ) |
| 792 | + |
| 793 | + dep_evs = _manager.submitted_events |
| 794 | + ht_ev, s_ev = _isin( |
| 795 | + needles=x_buf, |
| 796 | + hay=test_buf, |
| 797 | + dst=dst, |
| 798 | + sycl_queue=exec_q, |
| 799 | + invert=invert, |
| 800 | + depends=dep_evs, |
| 801 | + ) |
| 802 | + _manager.add_event_pair(ht_ev, s_ev) |
| 803 | + return dst |
0 commit comments