Skip to content

Commit f324df3

Browse files
committed
Fix handling of Python scalars in get_incompatible_dtype for NumPy 2.x compatibility
1 parent f6b39f8 commit f324df3

1 file changed

Lines changed: 20 additions & 6 deletions

File tree

deeplake/util/casting.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,17 @@ def _get_bigger_dtype(d1, d2):
2020
return np.object
2121

2222

23+
# Map Python scalar types to their default NumPy equivalents.
24+
# NumPy 2.x (NEP 50) no longer accepts raw Python scalars in np.can_cast,
25+
# so we resolve them to a concrete dtype before calling can_cast.
26+
_PYTHON_TYPE_TO_NUMPY_DTYPE = {
27+
float: np.float64,
28+
int: np.int64,
29+
bool: np.bool_,
30+
complex: np.complex128,
31+
}
32+
33+
2334
def get_dtype(val: Union[np.ndarray, Sequence, Sample]) -> np.dtype:
2435
"""Get the dtype of a non-uniform mixed dtype sequence of samples."""
2536

@@ -133,12 +144,15 @@ def get_incompatible_dtype(
133144
elif samples.size == 1:
134145
samples = samples.reshape(1).tolist()[0]
135146

136-
if isinstance(samples, (int, float, bool)) or hasattr(samples, "dtype"):
137-
return (
138-
None
139-
if np.can_cast(samples, dtype)
140-
else getattr(samples, "dtype", np.array(samples).dtype)
141-
)
147+
py_type = type(samples)
148+
if py_type in _PYTHON_TYPE_TO_NUMPY_DTYPE:
149+
# NumPy 2.x (NEP 50) removed support for passing raw Python scalars to
150+
# np.can_cast. Convert to the equivalent NumPy dtype first.
151+
from_dtype = _PYTHON_TYPE_TO_NUMPY_DTYPE[py_type]
152+
return None if np.can_cast(from_dtype, dtype, casting="same_kind") else from_dtype
153+
elif hasattr(samples, "dtype"):
154+
from_dtype = samples.dtype
155+
return None if np.can_cast(from_dtype, dtype, casting="same_kind") else from_dtype
142156
elif isinstance(samples, str):
143157
return None if dtype == np.dtype(str) else np.dtype(str)
144158
elif isinstance(samples, Sequence):

0 commit comments

Comments
 (0)