Skip to content

Commit 3c428a6

Browse files
Switch fully to dpctl_ext.tensor in dpctl_ext.tensor
1 parent 39c0571 commit 3c428a6

19 files changed

Lines changed: 387 additions & 424 deletions

dpctl_ext/tensor/_accumulation.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,11 @@
2727
# *****************************************************************************
2828

2929
import dpctl
30-
import dpctl.tensor as dpt
3130
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager
3231

3332
# TODO: revert to `import dpctl.tensor...`
3433
# when dpnp fully migrates dpctl/tensor
35-
import dpctl_ext.tensor as dpt_ext
34+
import dpctl_ext.tensor as dpt
3635
import dpctl_ext.tensor._tensor_accumulation_impl as tai
3736
import dpctl_ext.tensor._tensor_impl as ti
3837

@@ -82,7 +81,7 @@ def _accumulate_common(
8281
perm = [i for i in range(nd) if i != axis] + [
8382
axis,
8483
]
85-
arr = dpt_ext.permute_dims(x, perm)
84+
arr = dpt.permute_dims(x, perm)
8685
q = x.sycl_queue
8786
inp_dt = x.dtype
8887
res_usm_type = x.usm_type
@@ -130,16 +129,16 @@ def _accumulate_common(
130129
)
131130
# permute out array dims if necessary
132131
if a1 != nd:
133-
out = dpt_ext.permute_dims(out, perm)
132+
out = dpt.permute_dims(out, perm)
134133
orig_out = out
135134
if ti._array_overlap(x, out) and implemented_types:
136-
out = dpt_ext.empty_like(out)
135+
out = dpt.empty_like(out)
137136
else:
138-
out = dpt_ext.empty(
137+
out = dpt.empty(
139138
res_sh, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
140139
)
141140
if a1 != nd:
142-
out = dpt_ext.permute_dims(out, perm)
141+
out = dpt.permute_dims(out, perm)
143142

144143
_manager = SequentialOrderManager[q]
145144
depends = _manager.submitted_events
@@ -166,7 +165,7 @@ def _accumulate_common(
166165
out = orig_out
167166
else:
168167
if _dtype_supported(res_dt, res_dt):
169-
tmp = dpt_ext.empty(
168+
tmp = dpt.empty(
170169
arr.shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
171170
)
172171
ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray(
@@ -191,18 +190,18 @@ def _accumulate_common(
191190
_manager.add_event_pair(ht_e, acc_ev)
192191
else:
193192
buf_dt = _default_accumulation_type_fn(inp_dt, q)
194-
tmp = dpt_ext.empty(
193+
tmp = dpt.empty(
195194
arr.shape, dtype=buf_dt, usm_type=res_usm_type, sycl_queue=q
196195
)
197196
ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray(
198197
src=arr, dst=tmp, sycl_queue=q, depends=depends
199198
)
200199
_manager.add_event_pair(ht_e_cpy, cpy_e)
201-
tmp_res = dpt_ext.empty(
200+
tmp_res = dpt.empty(
202201
res_sh, dtype=buf_dt, usm_type=res_usm_type, sycl_queue=q
203202
)
204203
if a1 != nd:
205-
tmp_res = dpt_ext.permute_dims(tmp_res, perm)
204+
tmp_res = dpt.permute_dims(tmp_res, perm)
206205
if not include_initial:
207206
ht_e, acc_ev = _accumulate_fn(
208207
src=tmp,
@@ -225,10 +224,10 @@ def _accumulate_common(
225224
_manager.add_event_pair(ht_e_cpy2, cpy_e2)
226225

227226
if appended_axis:
228-
out = dpt_ext.squeeze(out)
227+
out = dpt.squeeze(out)
229228
if a1 != nd:
230229
inv_perm = sorted(range(nd), key=lambda d: perm[d])
231-
out = dpt_ext.permute_dims(out, inv_perm)
230+
out = dpt.permute_dims(out, inv_perm)
232231

233232
return out
234233

dpctl_ext/tensor/_clip.py

Lines changed: 40 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,11 @@
2727
# *****************************************************************************
2828

2929
import dpctl
30-
import dpctl.tensor as dpt
3130
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager
3231

3332
# TODO: revert to `import dpctl.tensor...`
3433
# when dpnp fully migrates dpctl/tensor
35-
import dpctl_ext.tensor as dpt_ext
34+
import dpctl_ext.tensor as dpt
3635
import dpctl_ext.tensor._tensor_elementwise_impl as tei
3736
import dpctl_ext.tensor._tensor_impl as ti
3837

@@ -163,20 +162,20 @@ def _clip_none(x, val, out, order, _binary_fn):
163162

164163
if ti._array_overlap(x, out):
165164
if not ti._same_logical_tensors(x, out):
166-
out = dpt_ext.empty_like(out)
165+
out = dpt.empty_like(out)
167166

168167
if isinstance(val, dpt.usm_ndarray):
169168
if (
170169
ti._array_overlap(val, out)
171170
and not ti._same_logical_tensors(val, out)
172171
and val_dtype == res_dt
173172
):
174-
out = dpt_ext.empty_like(out)
173+
out = dpt.empty_like(out)
175174

176175
if isinstance(val, dpt.usm_ndarray):
177176
val_ary = val
178177
else:
179-
val_ary = dpt_ext.asarray(val, dtype=val_dtype, sycl_queue=exec_q)
178+
val_ary = dpt.asarray(val, dtype=val_dtype, sycl_queue=exec_q)
180179

181180
if order == "A":
182181
order = (
@@ -197,17 +196,17 @@ def _clip_none(x, val, out, order, _binary_fn):
197196
x, val_ary, res_dt, res_shape, res_usm_type, exec_q
198197
)
199198
else:
200-
out = dpt_ext.empty(
199+
out = dpt.empty(
201200
res_shape,
202201
dtype=res_dt,
203202
usm_type=res_usm_type,
204203
sycl_queue=exec_q,
205204
order=order,
206205
)
207206
if x_shape != res_shape:
208-
x = dpt_ext.broadcast_to(x, res_shape)
207+
x = dpt.broadcast_to(x, res_shape)
209208
if val_ary.shape != res_shape:
210-
val_ary = dpt_ext.broadcast_to(val_ary, res_shape)
209+
val_ary = dpt.broadcast_to(val_ary, res_shape)
211210
_manager = SequentialOrderManager[exec_q]
212211
dep_evs = _manager.submitted_events
213212
ht_binary_ev, binary_ev = _binary_fn(
@@ -229,7 +228,7 @@ def _clip_none(x, val, out, order, _binary_fn):
229228
if order == "K":
230229
buf = _empty_like_orderK(val_ary, res_dt)
231230
else:
232-
buf = dpt_ext.empty_like(val_ary, dtype=res_dt, order=order)
231+
buf = dpt.empty_like(val_ary, dtype=res_dt, order=order)
233232
_manager = SequentialOrderManager[exec_q]
234233
dep_evs = _manager.submitted_events
235234
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
@@ -242,7 +241,7 @@ def _clip_none(x, val, out, order, _binary_fn):
242241
x, buf, res_dt, res_shape, res_usm_type, exec_q
243242
)
244243
else:
245-
out = dpt_ext.empty(
244+
out = dpt.empty(
246245
res_shape,
247246
dtype=res_dt,
248247
usm_type=res_usm_type,
@@ -251,8 +250,8 @@ def _clip_none(x, val, out, order, _binary_fn):
251250
)
252251

253252
if x_shape != res_shape:
254-
x = dpt_ext.broadcast_to(x, res_shape)
255-
buf = dpt_ext.broadcast_to(buf, res_shape)
253+
x = dpt.broadcast_to(x, res_shape)
254+
buf = dpt.broadcast_to(buf, res_shape)
256255
ht_binary_ev, binary_ev = _binary_fn(
257256
src1=x,
258257
src2=buf,
@@ -313,9 +312,9 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
313312
if order not in ["K", "C", "F", "A"]:
314313
order = "K"
315314
if x.dtype.kind in "iu":
316-
if isinstance(min, int) and min <= dpt_ext.iinfo(x.dtype).min:
315+
if isinstance(min, int) and min <= dpt.iinfo(x.dtype).min:
317316
min = None
318-
if isinstance(max, int) and max >= dpt_ext.iinfo(x.dtype).max:
317+
if isinstance(max, int) and max >= dpt.iinfo(x.dtype).max:
319318
max = None
320319
if min is None and max is None:
321320
exec_q = x.sycl_queue
@@ -353,14 +352,14 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
353352

354353
if ti._array_overlap(x, out):
355354
if not ti._same_logical_tensors(x, out):
356-
out = dpt_ext.empty_like(out)
355+
out = dpt.empty_like(out)
357356
else:
358357
return out
359358
else:
360359
if order == "K":
361360
out = _empty_like_orderK(x, x.dtype)
362361
else:
363-
out = dpt_ext.empty_like(x, order=order)
362+
out = dpt.empty_like(x, order=order)
364363

365364
_manager = SequentialOrderManager[exec_q]
366365
dep_evs = _manager.submitted_events
@@ -519,32 +518,32 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
519518

520519
if ti._array_overlap(x, out):
521520
if not ti._same_logical_tensors(x, out):
522-
out = dpt_ext.empty_like(out)
521+
out = dpt.empty_like(out)
523522

524523
if isinstance(min, dpt.usm_ndarray):
525524
if (
526525
ti._array_overlap(min, out)
527526
and not ti._same_logical_tensors(min, out)
528527
and buf1_dt is None
529528
):
530-
out = dpt_ext.empty_like(out)
529+
out = dpt.empty_like(out)
531530

532531
if isinstance(max, dpt.usm_ndarray):
533532
if (
534533
ti._array_overlap(max, out)
535534
and not ti._same_logical_tensors(max, out)
536535
and buf2_dt is None
537536
):
538-
out = dpt_ext.empty_like(out)
537+
out = dpt.empty_like(out)
539538

540539
if isinstance(min, dpt.usm_ndarray):
541540
a_min = min
542541
else:
543-
a_min = dpt_ext.asarray(min, dtype=min_dtype, sycl_queue=exec_q)
542+
a_min = dpt.asarray(min, dtype=min_dtype, sycl_queue=exec_q)
544543
if isinstance(max, dpt.usm_ndarray):
545544
a_max = max
546545
else:
547-
a_max = dpt_ext.asarray(max, dtype=max_dtype, sycl_queue=exec_q)
546+
a_max = dpt.asarray(max, dtype=max_dtype, sycl_queue=exec_q)
548547

549548
if order == "A":
550549
order = (
@@ -572,19 +571,19 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
572571
exec_q,
573572
)
574573
else:
575-
out = dpt_ext.empty(
574+
out = dpt.empty(
576575
res_shape,
577576
dtype=res_dt,
578577
usm_type=res_usm_type,
579578
sycl_queue=exec_q,
580579
order=order,
581580
)
582581
if x_shape != res_shape:
583-
x = dpt_ext.broadcast_to(x, res_shape)
582+
x = dpt.broadcast_to(x, res_shape)
584583
if a_min.shape != res_shape:
585-
a_min = dpt_ext.broadcast_to(a_min, res_shape)
584+
a_min = dpt.broadcast_to(a_min, res_shape)
586585
if a_max.shape != res_shape:
587-
a_max = dpt_ext.broadcast_to(a_max, res_shape)
586+
a_max = dpt.broadcast_to(a_max, res_shape)
588587
_manager = SequentialOrderManager[exec_q]
589588
dep_ev = _manager.submitted_events
590589
ht_binary_ev, binary_ev = ti._clip(
@@ -612,7 +611,7 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
612611
if order == "K":
613612
buf2 = _empty_like_orderK(a_max, buf2_dt)
614613
else:
615-
buf2 = dpt_ext.empty_like(a_max, dtype=buf2_dt, order=order)
614+
buf2 = dpt.empty_like(a_max, dtype=buf2_dt, order=order)
616615
_manager = SequentialOrderManager[exec_q]
617616
dep_ev = _manager.submitted_events
618617
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
@@ -631,18 +630,18 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
631630
exec_q,
632631
)
633632
else:
634-
out = dpt_ext.empty(
633+
out = dpt.empty(
635634
res_shape,
636635
dtype=res_dt,
637636
usm_type=res_usm_type,
638637
sycl_queue=exec_q,
639638
order=order,
640639
)
641640

642-
x = dpt_ext.broadcast_to(x, res_shape)
641+
x = dpt.broadcast_to(x, res_shape)
643642
if a_min.shape != res_shape:
644-
a_min = dpt_ext.broadcast_to(a_min, res_shape)
645-
buf2 = dpt_ext.broadcast_to(buf2, res_shape)
643+
a_min = dpt.broadcast_to(a_min, res_shape)
644+
buf2 = dpt.broadcast_to(buf2, res_shape)
646645
ht_binary_ev, binary_ev = ti._clip(
647646
src=x,
648647
min=a_min,
@@ -668,7 +667,7 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
668667
if order == "K":
669668
buf1 = _empty_like_orderK(a_min, buf1_dt)
670669
else:
671-
buf1 = dpt_ext.empty_like(a_min, dtype=buf1_dt, order=order)
670+
buf1 = dpt.empty_like(a_min, dtype=buf1_dt, order=order)
672671
_manager = SequentialOrderManager[exec_q]
673672
dep_ev = _manager.submitted_events
674673
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
@@ -687,18 +686,18 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
687686
exec_q,
688687
)
689688
else:
690-
out = dpt_ext.empty(
689+
out = dpt.empty(
691690
res_shape,
692691
dtype=res_dt,
693692
usm_type=res_usm_type,
694693
sycl_queue=exec_q,
695694
order=order,
696695
)
697696

698-
x = dpt_ext.broadcast_to(x, res_shape)
699-
buf1 = dpt_ext.broadcast_to(buf1, res_shape)
697+
x = dpt.broadcast_to(x, res_shape)
698+
buf1 = dpt.broadcast_to(buf1, res_shape)
700699
if a_max.shape != res_shape:
701-
a_max = dpt_ext.broadcast_to(a_max, res_shape)
700+
a_max = dpt.broadcast_to(a_max, res_shape)
702701
ht_binary_ev, binary_ev = ti._clip(
703702
src=x,
704703
min=buf1,
@@ -736,7 +735,7 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
736735
if order == "K":
737736
buf1 = _empty_like_orderK(a_min, buf1_dt)
738737
else:
739-
buf1 = dpt_ext.empty_like(a_min, dtype=buf1_dt, order=order)
738+
buf1 = dpt.empty_like(a_min, dtype=buf1_dt, order=order)
740739

741740
_manager = SequentialOrderManager[exec_q]
742741
dep_evs = _manager.submitted_events
@@ -747,7 +746,7 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
747746
if order == "K":
748747
buf2 = _empty_like_orderK(a_max, buf2_dt)
749748
else:
750-
buf2 = dpt_ext.empty_like(a_max, dtype=buf2_dt, order=order)
749+
buf2 = dpt.empty_like(a_max, dtype=buf2_dt, order=order)
751750
ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray(
752751
src=a_max, dst=buf2, sycl_queue=exec_q, depends=dep_evs
753752
)
@@ -758,17 +757,17 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
758757
x, buf1, buf2, res_dt, res_shape, res_usm_type, exec_q
759758
)
760759
else:
761-
out = dpt_ext.empty(
760+
out = dpt.empty(
762761
res_shape,
763762
dtype=res_dt,
764763
usm_type=res_usm_type,
765764
sycl_queue=exec_q,
766765
order=order,
767766
)
768767

769-
x = dpt_ext.broadcast_to(x, res_shape)
770-
buf1 = dpt_ext.broadcast_to(buf1, res_shape)
771-
buf2 = dpt_ext.broadcast_to(buf2, res_shape)
768+
x = dpt.broadcast_to(x, res_shape)
769+
buf1 = dpt.broadcast_to(buf1, res_shape)
770+
buf2 = dpt.broadcast_to(buf2, res_shape)
772771
ht_, clip_ev = ti._clip(
773772
src=x,
774773
min=buf1,

0 commit comments

Comments
 (0)