Skip to content

Commit aa1e519

Browse files
Merge branch 'master' into temporary-disable-test-test_minimum_nan-for-float16
2 parents 9b742d9 + da526ca commit aa1e519

4 files changed

Lines changed: 86 additions & 15 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1616

1717
### Fixed
1818

19+
* Fixed incorrect in-place advanced indexing for 4D arrays when using `range` or `list` as index keys [#2872](https://github.com/IntelPython/dpnp/pull/2872)
1920
* Fixed `conda build` command syntax in GitHub workflows and documentation to use `conda-build` [#2888](https://github.com/IntelPython/dpnp/pull/2888)
2021

2122
### Security

dpnp/dpnp_array.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838

3939
import warnings
4040

41+
import numpy
42+
4143
import dpnp
4244
import dpnp.tensor as dpt
4345
import dpnp.tensor._type_utils as dtu
@@ -46,24 +48,45 @@
4648
from .exceptions import AxisError
4749

4850

51+
def _unwrap_index_element(x):
52+
"""
53+
Unwrap a single index element for the tensor indexing layer.
54+
55+
Converts dpnp arrays to usm_ndarray and array-like objects (range, list)
56+
to numpy arrays with intp dtype for NumPy-compatible advanced indexing.
57+
58+
"""
59+
60+
if isinstance(x, dpt.usm_ndarray):
61+
return x
62+
if isinstance(x, dpnp_array):
63+
return x.get_array()
64+
if isinstance(x, range):
65+
return numpy.asarray(x, dtype=numpy.intp)
66+
if isinstance(x, list):
67+
# keep boolean lists as boolean
68+
arr = numpy.asarray(x)
69+
# cast empty lists (float64 in NumPy) to intp
70+
# for correct tensor indexing
71+
if arr.size == 0:
72+
arr = arr.astype(numpy.intp)
73+
return arr
74+
return x
75+
76+
4977
def _get_unwrapped_index_key(key):
5078
"""
5179
Get an unwrapped index key.
5280
5381
Return a key where each nested instance of DPNP array is unwrapped into
54-
USM ndarray for further processing in DPCTL advanced indexing functions.
82+
USM ndarray, and array-like objects (range, list) are converted to numpy
83+
arrays for further processing in advanced indexing functions.
5584
5685
"""
5786

5887
if isinstance(key, tuple):
59-
if any(isinstance(x, dpnp_array) for x in key):
60-
# create a new tuple from the input key with unwrapped DPNP arrays
61-
return tuple(
62-
x.get_array() if isinstance(x, dpnp_array) else x for x in key
63-
)
64-
elif isinstance(key, dpnp_array):
65-
return key.get_array()
66-
return key
88+
return tuple(_unwrap_index_element(x) for x in key)
89+
return _unwrap_index_element(key)
6790

6891

6992
# pylint: disable=too-many-public-methods

dpnp/tensor/_slicing.pxi

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,6 @@ cdef bint _is_boolean(object x) except *:
104104
return f in "?"
105105
else:
106106
return False
107-
if callable(getattr(x, "__bool__", None)):
108-
try:
109-
x.__bool__()
110-
except (TypeError, ValueError):
111-
return False
112-
return True
113107
return False
114108

115109

dpnp/tests/test_indexing.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,59 @@ def test_indexing_array_negative_strides(self):
353353
arr[slices] = 10
354354
assert_equal(arr, 10.0, strict=False)
355355

356+
@pytest.mark.parametrize(
357+
"idx",
358+
[
359+
(range(2), range(2)),
360+
([0, 1], [0, 1]),
361+
],
362+
ids=["range", "list"],
363+
)
364+
def test_array_like_index_getitem(self, idx):
365+
np_a = numpy.arange(36).reshape(2, 2, 3, 3)
366+
dp_a = dpnp.arange(36).reshape(2, 2, 3, 3)
367+
assert_array_equal(dp_a[idx], np_a[idx])
368+
369+
@pytest.mark.parametrize(
370+
"idx",
371+
[
372+
(range(2), range(2)),
373+
([0, 1], [0, 1]),
374+
],
375+
ids=["range", "list"],
376+
)
377+
def test_array_like_index_setitem(self, idx):
378+
np_a = numpy.arange(36).reshape(2, 2, 3, 3)
379+
dp_a = dpnp.arange(36).reshape(2, 2, 3, 3)
380+
np_a[idx] = 0
381+
dp_a[idx] = 0
382+
assert_array_equal(dp_a, np_a)
383+
384+
def test_array_like_index_inplace_add(self):
385+
np_a = numpy.arange(36).reshape(2, 2, 3, 3)
386+
dp_a = dpnp.arange(36).reshape(2, 2, 3, 3)
387+
np_tmp = -numpy.ones((2, 3, 3), dtype=numpy.intp)
388+
dp_tmp = -dpnp.ones((2, 3, 3), dtype=numpy.intp)
389+
390+
np_a[range(2), range(2)] += 2 * np_tmp
391+
dp_a[range(2), range(2)] += 2 * dp_tmp
392+
assert_array_equal(dp_a, np_a)
393+
394+
@pytest.mark.parametrize(
395+
"idx",
396+
[
397+
range(2),
398+
[0, 1],
399+
range(0),
400+
[],
401+
],
402+
ids=["range", "list", "empty_range", "empty_list"],
403+
)
404+
def test_array_like_single_index(self, idx):
405+
np_a = numpy.arange(24).reshape(2, 3, 4)
406+
dp_a = dpnp.arange(24).reshape(2, 3, 4)
407+
assert_array_equal(dp_a[idx], np_a[idx])
408+
356409

357410
class TestIx:
358411
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)