@@ -20,6 +20,17 @@ def _get_bigger_dtype(d1, d2):
2020 return np .object
2121
2222
23+ # Map Python scalar types to their default NumPy equivalents.
24+ # NumPy 2.x (NEP 50) no longer accepts raw Python scalars in np.can_cast,
25+ # so we resolve them to a concrete dtype before calling can_cast.
26+ _PYTHON_TYPE_TO_NUMPY_DTYPE = {
27+ float : np .float64 ,
28+ int : np .int64 ,
29+ bool : np .bool_ ,
30+ complex : np .complex128 ,
31+ }
32+
33+
2334def get_dtype (val : Union [np .ndarray , Sequence , Sample ]) -> np .dtype :
2435 """Get the dtype of a non-uniform mixed dtype sequence of samples."""
2536
@@ -133,12 +144,15 @@ def get_incompatible_dtype(
133144 elif samples .size == 1 :
134145 samples = samples .reshape (1 ).tolist ()[0 ]
135146
136- if isinstance (samples , (int , float , bool )) or hasattr (samples , "dtype" ):
137- return (
138- None
139- if np .can_cast (samples , dtype )
140- else getattr (samples , "dtype" , np .array (samples ).dtype )
141- )
147+ py_type = type (samples )
148+ if py_type in _PYTHON_TYPE_TO_NUMPY_DTYPE :
149+ # NumPy 2.x (NEP 50) removed support for passing raw Python scalars to
150+ # np.can_cast. Convert to the equivalent NumPy dtype first.
151+ from_dtype = _PYTHON_TYPE_TO_NUMPY_DTYPE [py_type ]
152+ return None if np .can_cast (from_dtype , dtype , casting = "same_kind" ) else from_dtype
153+ elif hasattr (samples , "dtype" ):
154+ from_dtype = samples .dtype
155+ return None if np .can_cast (from_dtype , dtype , casting = "same_kind" ) else from_dtype
142156 elif isinstance (samples , str ):
143157 return None if dtype == np .dtype (str ) else np .dtype (str )
144158 elif isinstance (samples , Sequence ):
0 commit comments