Skip to content

Commit b8ad5ec

Browse files
Move _isin to _tensor_sorting_impl
1 parent 9766d34 commit b8ad5ec

9 files changed

Lines changed: 623 additions & 11 deletions

File tree

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ set(_accumulator_sources
7070
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_sum.cpp
7171
)
7272
set(_sorting_sources
73-
#{CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/isin.cpp
73+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/isin.cpp
7474
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_sort.cpp
7575
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_argsort.cpp
7676
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_sort.cpp
Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===----------------------------------------------------------------------===//
30+
///
31+
/// \file
32+
/// This file defines kernels for tensor membership operations.
33+
//===----------------------------------------------------------------------===//
34+
35+
#pragma once
36+
37+
#include <cstddef>
38+
#include <vector>
39+
40+
#include <sycl/sycl.hpp>
41+
42+
#include "kernels/dpctl_tensor_types.hpp"
43+
#include "kernels/sorting/search_sorted_detail.hpp"
44+
#include "utils/offset_utils.hpp"
45+
#include "utils/rich_comparisons.hpp"
46+
47+
namespace dpctl::tensor::kernels
48+
{
49+
50+
using dpctl::tensor::ssize_t;
51+
52+
template <typename T,
53+
typename HayIndexerT,
54+
typename NeedlesIndexerT,
55+
typename OutIndexerT>
56+
struct IsinFunctor
57+
{
58+
private:
59+
bool invert;
60+
const T *hay_tp;
61+
const T *needles_tp;
62+
bool *out_tp;
63+
std::size_t hay_nelems;
64+
HayIndexerT hay_indexer;
65+
NeedlesIndexerT needles_indexer;
66+
OutIndexerT out_indexer;
67+
68+
public:
69+
IsinFunctor(const bool invert_,
70+
const T *hay_,
71+
const T *needles_,
72+
bool *out_,
73+
const std::size_t hay_nelems_,
74+
const HayIndexerT &hay_indexer_,
75+
const NeedlesIndexerT &needles_indexer_,
76+
const OutIndexerT &out_indexer_)
77+
: invert(invert_), hay_tp(hay_), needles_tp(needles_), out_tp(out_),
78+
hay_nelems(hay_nelems_), hay_indexer(hay_indexer_),
79+
needles_indexer(needles_indexer_), out_indexer(out_indexer_)
80+
{
81+
}
82+
83+
void operator()(sycl::id<1> id) const
84+
{
85+
using Compare =
86+
typename dpctl::tensor::rich_comparisons::AscendingSorter<T>::type;
87+
static constexpr Compare comp{};
88+
89+
const std::size_t i = id[0];
90+
const T needle_v = needles_tp[needles_indexer(i)];
91+
92+
// position of the needle_v in the hay array
93+
std::size_t pos{};
94+
95+
static constexpr std::size_t zero(0);
96+
// search in hay in left-closed interval, give `pos` such that
97+
// hay[pos - 1] < needle_v <= hay[pos]
98+
99+
// lower_bound returns the first pos such that bool(hay[pos] <
100+
// needle_v) is false, i.e. needle_v <= hay[pos]
101+
pos = search_sorted_detail::lower_bound_indexed_impl(
102+
hay_tp, zero, hay_nelems, needle_v, comp, hay_indexer);
103+
bool out = (pos == hay_nelems ? false : hay_tp[pos] == needle_v);
104+
out_tp[out_indexer(i)] = (invert) ? !out : out;
105+
}
106+
};
107+
108+
typedef sycl::event (*isin_contig_impl_fp_ptr_t)(
109+
sycl::queue &,
110+
const bool,
111+
const std::size_t,
112+
const std::size_t,
113+
const char *,
114+
const ssize_t,
115+
const char *,
116+
const ssize_t,
117+
char *,
118+
const ssize_t,
119+
const std::vector<sycl::event> &);
120+
121+
template <typename T>
122+
class isin_contig_impl_krn;
123+
124+
template <typename T>
125+
sycl::event isin_contig_impl(sycl::queue &exec_q,
126+
const bool invert,
127+
const std::size_t hay_nelems,
128+
const std::size_t needles_nelems,
129+
const char *hay_cp,
130+
const ssize_t hay_offset,
131+
const char *needles_cp,
132+
const ssize_t needles_offset,
133+
char *out_cp,
134+
const ssize_t out_offset,
135+
const std::vector<sycl::event> &depends)
136+
{
137+
const T *hay_tp = reinterpret_cast<const T *>(hay_cp) + hay_offset;
138+
const T *needles_tp =
139+
reinterpret_cast<const T *>(needles_cp) + needles_offset;
140+
141+
bool *out_tp = reinterpret_cast<bool *>(out_cp) + out_offset;
142+
143+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
144+
cgh.depends_on(depends);
145+
146+
using KernelName = class isin_contig_impl_krn<T>;
147+
148+
sycl::range<1> gRange(needles_nelems);
149+
150+
using TrivialIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
151+
152+
static constexpr TrivialIndexerT hay_indexer{};
153+
static constexpr TrivialIndexerT needles_indexer{};
154+
static constexpr TrivialIndexerT out_indexer{};
155+
156+
const auto fnctr =
157+
IsinFunctor<T, TrivialIndexerT, TrivialIndexerT, TrivialIndexerT>(
158+
invert, hay_tp, needles_tp, out_tp, hay_nelems, hay_indexer,
159+
needles_indexer, out_indexer);
160+
161+
cgh.parallel_for<KernelName>(gRange, fnctr);
162+
});
163+
164+
return comp_ev;
165+
}
166+
167+
typedef sycl::event (*isin_strided_impl_fp_ptr_t)(
168+
sycl::queue &,
169+
const bool,
170+
const std::size_t,
171+
const std::size_t,
172+
const char *,
173+
const ssize_t,
174+
const ssize_t,
175+
const char *,
176+
const ssize_t,
177+
char *,
178+
const ssize_t,
179+
int,
180+
const ssize_t *,
181+
const std::vector<sycl::event> &);
182+
183+
template <typename T>
184+
class isin_strided_impl_krn;
185+
186+
template <typename T>
187+
sycl::event isin_strided_impl(
188+
sycl::queue &exec_q,
189+
const bool invert,
190+
const std::size_t hay_nelems,
191+
const std::size_t needles_nelems,
192+
const char *hay_cp,
193+
const ssize_t hay_offset,
194+
// hay is 1D, so hay_nelems, hay_offset, hay_stride describe strided array
195+
const ssize_t hay_stride,
196+
const char *needles_cp,
197+
const ssize_t needles_offset,
198+
char *out_cp,
199+
const ssize_t out_offset,
200+
const int needles_nd,
201+
// packed_shape_strides is [needles_shape, needles_strides,
202+
// out_strides] has length of 3*needles_nd
203+
const ssize_t *packed_shape_strides,
204+
const std::vector<sycl::event> &depends)
205+
{
206+
const T *hay_tp = reinterpret_cast<const T *>(hay_cp);
207+
const T *needles_tp = reinterpret_cast<const T *>(needles_cp);
208+
209+
bool *out_tp = reinterpret_cast<bool *>(out_cp);
210+
211+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
212+
cgh.depends_on(depends);
213+
214+
sycl::range<1> gRange(needles_nelems);
215+
216+
using HayIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer;
217+
const HayIndexerT hay_indexer(
218+
/* offset */ hay_offset,
219+
/* size */ hay_nelems,
220+
/* step */ hay_stride);
221+
222+
using NeedlesIndexerT = dpctl::tensor::offset_utils::StridedIndexer;
223+
const ssize_t *needles_shape_strides = packed_shape_strides;
224+
const NeedlesIndexerT needles_indexer(needles_nd, needles_offset,
225+
needles_shape_strides);
226+
using OutIndexerT = dpctl::tensor::offset_utils::UnpackedStridedIndexer;
227+
228+
const ssize_t *out_shape = packed_shape_strides;
229+
const ssize_t *out_strides = packed_shape_strides + 2 * needles_nd;
230+
const OutIndexerT out_indexer(needles_nd, out_offset, out_shape,
231+
out_strides);
232+
233+
const auto fnctr =
234+
IsinFunctor<T, HayIndexerT, NeedlesIndexerT, OutIndexerT>(
235+
invert, hay_tp, needles_tp, out_tp, hay_nelems, hay_indexer,
236+
needles_indexer, out_indexer);
237+
using KernelName = class isin_strided_impl_krn<T>;
238+
239+
cgh.parallel_for<KernelName>(gRange, fnctr);
240+
});
241+
242+
return comp_ev;
243+
}
244+
245+
} // namespace dpctl::tensor::kernels

dpctl_ext/tensor/libtensor/include/kernels/sorting/merge_sort.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@
2929
//===----------------------------------------------------------------------===//
3030
///
3131
/// \file
32-
/// This file defines functions of dpctl.tensor._tensor_sorting_impl
33-
/// extension.
32+
/// This file defines kernels for tensor sort/argsort operations.
3433
//===----------------------------------------------------------------------===//
3534

3635
#pragma once

dpctl_ext/tensor/libtensor/include/kernels/sorting/search_sorted_detail.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@
2929
//===----------------------------------------------------------------------===//
3030
///
3131
/// \file
32-
/// This file defines functions of dpctl.tensor._tensor_sorting_impl
33-
/// extension.
32+
/// This file defines kernels for tensor sort/argsort operations.
3433
//===----------------------------------------------------------------------===//
3534

3635
#pragma once

dpctl_ext/tensor/libtensor/include/kernels/sorting/searchsorted.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@
2929
//===----------------------------------------------------------------------===//
3030
///
3131
/// \file
32-
/// This file defines functions of dpctl.tensor._tensor_sorting_impl
33-
/// extension.
32+
/// This file defines kernels for tensor sort/argsort operations.
3433
//===----------------------------------------------------------------------===//
3534

3635
#pragma once

dpctl_ext/tensor/libtensor/include/kernels/sorting/sort_utils.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@
2929
//===----------------------------------------------------------------------===//
3030
///
3131
/// \file
32-
/// This file defines functions of dpctl.tensor._tensor_sorting_impl
33-
/// extension.
32+
/// This file defines kernels for tensor sort/argsort operations.
3433
//===----------------------------------------------------------------------===//
3534

3635
#pragma once

0 commit comments

Comments
 (0)