Skip to content

Commit 88a23a2

Browse files
Move _searchsorted_left/right to _tensor_sorting_impl
1 parent 6912311 commit 88a23a2

5 files changed

Lines changed: 788 additions & 3 deletions

File tree

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ set(_sorting_sources
7575
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_argsort.cpp
7676
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_sort.cpp
7777
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_argsort.cpp
78-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/searchsorted.cpp
78+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/searchsorted.cpp
7979
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/topk.cpp
8080
)
8181
set(_tensor_accumulation_impl_sources
Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===----------------------------------------------------------------------===//
30+
///
31+
/// \file
32+
/// This file defines functions of dpctl.tensor._tensor_sorting_impl
33+
/// extension.
34+
//===----------------------------------------------------------------------===//
35+
36+
#pragma once
37+
38+
#include <cstddef>
39+
#include <cstdint>
40+
#include <vector>
41+
42+
#include <sycl/sycl.hpp>
43+
44+
#include "kernels/dpctl_tensor_types.hpp"
45+
#include "kernels/sorting/search_sorted_detail.hpp"
46+
#include "utils/offset_utils.hpp"
47+
48+
namespace dpctl::tensor::kernels
49+
{
50+
51+
using dpctl::tensor::ssize_t;
52+
53+
template <typename argTy,
54+
typename indTy,
55+
bool left_side,
56+
typename HayIndexerT,
57+
typename NeedlesIndexerT,
58+
typename PositionsIndexerT,
59+
typename Compare>
60+
struct SearchSortedFunctor
61+
{
62+
private:
63+
const argTy *hay_tp;
64+
const argTy *needles_tp;
65+
indTy *positions_tp;
66+
std::size_t hay_nelems;
67+
HayIndexerT hay_indexer;
68+
NeedlesIndexerT needles_indexer;
69+
PositionsIndexerT positions_indexer;
70+
71+
public:
72+
SearchSortedFunctor(const argTy *hay_,
73+
const argTy *needles_,
74+
indTy *positions_,
75+
const std::size_t hay_nelems_,
76+
const HayIndexerT &hay_indexer_,
77+
const NeedlesIndexerT &needles_indexer_,
78+
const PositionsIndexerT &positions_indexer_)
79+
: hay_tp(hay_), needles_tp(needles_), positions_tp(positions_),
80+
hay_nelems(hay_nelems_), hay_indexer(hay_indexer_),
81+
needles_indexer(needles_indexer_),
82+
positions_indexer(positions_indexer_)
83+
{
84+
}
85+
86+
void operator()(sycl::id<1> id) const
87+
{
88+
const Compare comp{};
89+
90+
const std::size_t i = id[0];
91+
const argTy needle_v = needles_tp[needles_indexer(i)];
92+
93+
// position of the needle_v in the hay array
94+
indTy pos{};
95+
96+
static constexpr std::size_t zero(0);
97+
if constexpr (left_side) {
98+
// search in hay in left-closed interval, give `pos` such that
99+
// hay[pos - 1] < needle_v <= hay[pos]
100+
101+
// lower_bound returns the first pos such that bool(hay[pos] <
102+
// needle_v) is false, i.e. needle_v <= hay[pos]
103+
pos = search_sorted_detail::lower_bound_indexed_impl(
104+
hay_tp, zero, hay_nelems, needle_v, comp, hay_indexer);
105+
}
106+
else {
107+
// search in hay in right-closed interval: hay[pos - 1] <= needle_v
108+
// < hay[pos]
109+
110+
// upper_bound returns the first pos such that bool(needle_v <
111+
// hay[pos]) is true, i.e. needle_v < hay[pos]
112+
pos = search_sorted_detail::upper_bound_indexed_impl(
113+
hay_tp, zero, hay_nelems, needle_v, comp, hay_indexer);
114+
}
115+
116+
positions_tp[positions_indexer(i)] = pos;
117+
}
118+
};
119+
120+
/// Pointer-to-function type matching the signature of
/// `searchsorted_contig_impl` instantiations.
using searchsorted_contig_impl_fp_ptr_t =
    sycl::event (*)(sycl::queue & /* exec_q */,
                    const std::size_t /* hay_nelems */,
                    const std::size_t /* needles_nelems */,
                    const char * /* hay_cp */,
                    const ssize_t /* hay_offset */,
                    const char * /* needles_cp */,
                    const ssize_t /* needles_offset */,
                    char * /* positions_cp */,
                    const ssize_t /* positions_offset */,
                    const std::vector<sycl::event> & /* depends */);
131+
132+
template <typename T1, typename T2, bool left_closed>
133+
class searchsorted_contig_impl_krn;
134+
135+
template <typename argTy, typename indTy, bool left_closed, typename Compare>
136+
sycl::event searchsorted_contig_impl(sycl::queue &exec_q,
137+
const std::size_t hay_nelems,
138+
const std::size_t needles_nelems,
139+
const char *hay_cp,
140+
const ssize_t hay_offset,
141+
const char *needles_cp,
142+
const ssize_t needles_offset,
143+
char *positions_cp,
144+
const ssize_t positions_offset,
145+
const std::vector<sycl::event> &depends)
146+
{
147+
const argTy *hay_tp = reinterpret_cast<const argTy *>(hay_cp) + hay_offset;
148+
const argTy *needles_tp =
149+
reinterpret_cast<const argTy *>(needles_cp) + needles_offset;
150+
151+
indTy *positions_tp =
152+
reinterpret_cast<indTy *>(positions_cp) + positions_offset;
153+
154+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
155+
cgh.depends_on(depends);
156+
157+
using KernelName =
158+
class searchsorted_contig_impl_krn<argTy, indTy, left_closed>;
159+
160+
sycl::range<1> gRange(needles_nelems);
161+
162+
using TrivialIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
163+
164+
static constexpr TrivialIndexerT hay_indexer{};
165+
static constexpr TrivialIndexerT needles_indexer{};
166+
static constexpr TrivialIndexerT positions_indexer{};
167+
168+
const auto fnctr =
169+
SearchSortedFunctor<argTy, indTy, left_closed, TrivialIndexerT,
170+
TrivialIndexerT, TrivialIndexerT, Compare>(
171+
hay_tp, needles_tp, positions_tp, hay_nelems, hay_indexer,
172+
needles_indexer, positions_indexer);
173+
174+
cgh.parallel_for<KernelName>(gRange, fnctr);
175+
});
176+
177+
return comp_ev;
178+
}
179+
180+
/// Pointer-to-function type matching the signature of
/// `searchsorted_strided_impl` instantiations.
using searchsorted_strided_impl_fp_ptr_t =
    sycl::event (*)(sycl::queue & /* exec_q */,
                    const std::size_t /* hay_nelems */,
                    const std::size_t /* needles_nelems */,
                    const char * /* hay_cp */,
                    const ssize_t /* hay_offset */,
                    const ssize_t /* hay_stride */,
                    const char * /* needles_cp */,
                    const ssize_t /* needles_offset */,
                    char * /* positions_cp */,
                    const ssize_t /* positions_offset */,
                    int /* needles_nd */,
                    const ssize_t * /* packed_shape_strides */,
                    const std::vector<sycl::event> & /* depends */);
194+
195+
template <typename T1, typename T2, bool left_closed>
196+
class searchsorted_strided_impl_krn;
197+
198+
template <typename argTy, typename indTy, bool left_closed, typename Compare>
199+
sycl::event searchsorted_strided_impl(
200+
sycl::queue &exec_q,
201+
const std::size_t hay_nelems,
202+
const std::size_t needles_nelems,
203+
const char *hay_cp,
204+
const ssize_t hay_offset,
205+
// hay is 1D, so hay_nelems, hay_offset, hay_stride describe strided array
206+
const ssize_t hay_stride,
207+
const char *needles_cp,
208+
const ssize_t needles_offset,
209+
char *positions_cp,
210+
const ssize_t positions_offset,
211+
const int needles_nd,
212+
// packed_shape_strides is [needles_shape, needles_strides,
213+
// positions_strides] has length of 3*needles_nd
214+
const ssize_t *packed_shape_strides,
215+
const std::vector<sycl::event> &depends)
216+
{
217+
const argTy *hay_tp = reinterpret_cast<const argTy *>(hay_cp);
218+
const argTy *needles_tp = reinterpret_cast<const argTy *>(needles_cp);
219+
220+
indTy *positions_tp = reinterpret_cast<indTy *>(positions_cp);
221+
222+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
223+
cgh.depends_on(depends);
224+
225+
sycl::range<1> gRange(needles_nelems);
226+
227+
using HayIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer;
228+
const HayIndexerT hay_indexer(
229+
/* offset */ hay_offset,
230+
/* size */ hay_nelems,
231+
/* step */ hay_stride);
232+
233+
using NeedlesIndexerT = dpctl::tensor::offset_utils::StridedIndexer;
234+
const ssize_t *needles_shape_strides = packed_shape_strides;
235+
const NeedlesIndexerT needles_indexer(needles_nd, needles_offset,
236+
needles_shape_strides);
237+
using PositionsIndexerT =
238+
dpctl::tensor::offset_utils::UnpackedStridedIndexer;
239+
240+
const ssize_t *positions_shape = packed_shape_strides;
241+
const ssize_t *positions_strides =
242+
packed_shape_strides + 2 * needles_nd;
243+
const PositionsIndexerT positions_indexer(
244+
needles_nd, positions_offset, positions_shape, positions_strides);
245+
246+
const auto fnctr =
247+
SearchSortedFunctor<argTy, indTy, left_closed, HayIndexerT,
248+
NeedlesIndexerT, PositionsIndexerT, Compare>(
249+
hay_tp, needles_tp, positions_tp, hay_nelems, hay_indexer,
250+
needles_indexer, positions_indexer);
251+
using KernelName =
252+
class searchsorted_strided_impl_krn<argTy, indTy, left_closed>;
253+
254+
cgh.parallel_for<KernelName>(gRange, fnctr);
255+
});
256+
257+
return comp_ev;
258+
}
259+
260+
} // namespace dpctl::tensor::kernels

0 commit comments

Comments
 (0)