Skip to content

Commit fdf66de

Browse files
authored
✨ specialized scalar dtype types (#514)
2 parents 2c1cd63 + a654546 commit fdf66de

3 files changed

Lines changed: 120 additions & 27 deletions

File tree

src/numpy-stubs/@test/static/accept/scalars.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ assert_type(c8.squeeze(), np.complex64)
3636
assert_type(c8.byteswap(), np.complex64)
3737
assert_type(c8.transpose(), np.complex64)
3838

39-
assert_type(c8.dtype, np.dtype[np.complex64])
39+
assert_type(c8.dtype, np.dtypes.Complex64DType)
4040

4141
assert_type(c8.real, np.float32)
4242
assert_type(c16.imag, np.float64)

src/numpy-stubs/__init__.pyi

Lines changed: 95 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3947,6 +3947,9 @@ class bool_(generic[_BoolItemT_co], Generic[_BoolItemT_co]):
39473947
#
39483948
@property
39493949
@override
3950+
def dtype(self) -> dtypes.BoolDType: ...
3951+
@property
3952+
@override
39503953
def itemsize(self) -> L[1]: ...
39513954
@property
39523955
@override
@@ -4377,8 +4380,6 @@ class bool_(generic[_BoolItemT_co], Generic[_BoolItemT_co]):
43774380

43784381
bool = bool_
43794382

4380-
# TODO(jorenham): Move these protocols to _numtype
4381-
43824383
class number(
43834384
_CmpOpMixin[_nt.CoComplex_0d, _nt.CoComplex_1nd],
43844385
generic[_NumberItemT_co],
@@ -4704,6 +4705,11 @@ class int8(_IntMixin[L[1]], signedinteger[_n._8]):
47044705
@type_check_only
47054706
def __nep50_rule3__(self, other: uint16, /) -> int32: ...
47064707

4708+
#
4709+
@property
4710+
@override
4711+
def dtype(self) -> dtypes.Int8DType: ...
4712+
47074713
byte = int8
47084714

47094715
class int16(_IntMixin[L[2]], signedinteger[_n._16]):
@@ -4720,6 +4726,11 @@ class int16(_IntMixin[L[2]], signedinteger[_n._16]):
47204726
@type_check_only
47214727
def __nep50_rule3__(self, other: float16, /) -> float32: ...
47224728

4729+
#
4730+
@property
4731+
@override
4732+
def dtype(self) -> dtypes.Int16DType: ...
4733+
47234734
short = int16
47244735

47254736
class int32(_IntMixin[L[4]], signedinteger[_n._32]):
@@ -4739,6 +4750,11 @@ class int32(_IntMixin[L[4]], signedinteger[_n._32]):
47394750
@type_check_only
47404751
def __nep50_rule3__(self, other: _JustComplexFloating, /) -> complexfloating: ...
47414752

4753+
#
4754+
@property
4755+
@override
4756+
def dtype(self) -> dtypes.Int32DType: ...
4757+
47424758
intc = int32
47434759

47444760
class int64(_IntMixin[L[8]], signedinteger[_n._64]):
@@ -4764,6 +4780,11 @@ class int64(_IntMixin[L[8]], signedinteger[_n._64]):
47644780
@type_check_only
47654781
def __nep50_rule5__(self, other: _JustInteger | _JustUnsignedInteger, /) -> Self | float64: ...
47664782

4783+
#
4784+
@property
4785+
@override
4786+
def dtype(self) -> dtypes.Int64DType: ...
4787+
47674788
if sys.platform == "win32":
47684789
long: TypeAlias = int32 # pyright: ignore[reportRedeclaration]
47694790
else:
@@ -4805,6 +4826,11 @@ class uint8(_IntMixin[L[1]], unsignedinteger[_n._8]):
48054826
@type_check_only
48064827
def __nep50_rule5__(self, other: _JustInteger, /) -> integer: ...
48074828

4829+
#
4830+
@property
4831+
@override
4832+
def dtype(self) -> dtypes.UInt8DType: ...
4833+
48084834
ubyte = uint8
48094835

48104836
class uint16(_IntMixin[L[2]], unsignedinteger[_n._16]):
@@ -4827,6 +4853,11 @@ class uint16(_IntMixin[L[2]], unsignedinteger[_n._16]):
48274853
@type_check_only
48284854
def __nep50_rule5__(self, other: _JustInteger, /) -> integer: ...
48294855

4856+
#
4857+
@property
4858+
@override
4859+
def dtype(self) -> dtypes.UInt16DType: ...
4860+
48304861
ushort = uint16
48314862

48324863
class uint32(_IntMixin[L[4]], unsignedinteger[_n._32]):
@@ -4849,6 +4880,11 @@ class uint32(_IntMixin[L[4]], unsignedinteger[_n._32]):
48494880
@type_check_only
48504881
def __nep50_rule5__(self, other: _JustInteger, /) -> integer: ...
48514882

4883+
#
4884+
@property
4885+
@override
4886+
def dtype(self) -> dtypes.UInt32DType: ...
4887+
48524888
uintc = uint32
48534889

48544890
class uint64(_IntMixin[L[8]], unsignedinteger[_n._64]):
@@ -4872,6 +4908,11 @@ class uint64(_IntMixin[L[8]], unsignedinteger[_n._64]):
48724908
@type_check_only
48734909
def __nep50_rule5__(self, other: _JustInteger, /) -> uint64 | float64: ...
48744910

4911+
#
4912+
@property
4913+
@override
4914+
def dtype(self) -> dtypes.UInt64DType: ...
4915+
48754916
if sys.platform == "win32":
48764917
ulong: TypeAlias = uint32 # pyright: ignore[reportRedeclaration]
48774918
else:
@@ -4982,6 +5023,11 @@ class float16(_FloatMixin[L[2]], floating[_n._16]):
49825023
@type_check_only
49835024
def __nep50_rule2__(self, other: _AbstractInteger, /) -> floating: ...
49845025

5026+
#
5027+
@property
5028+
@override
5029+
def dtype(self) -> dtypes.Float16DType: ...
5030+
49855031
half = float16
49865032

49875033
class float32(_FloatMixin[L[4]], floating[_n._32]):
@@ -4996,6 +5042,11 @@ class float32(_FloatMixin[L[4]], floating[_n._32]):
49965042
@type_check_only
49975043
def __nep50_rule2__(self, other: _AbstractInteger, /) -> floating: ...
49985044

5045+
#
5046+
@property
5047+
@override
5048+
def dtype(self) -> dtypes.Float32DType: ...
5049+
49995050
single = float32
50005051

50015052
class float64(_FloatMixin[L[8]], floating[_n._64], float): # type: ignore[misc]
@@ -5016,6 +5067,9 @@ class float64(_FloatMixin[L[8]], floating[_n._64], float): # type: ignore[misc]
50165067
#
50175068
@property
50185069
@override
5070+
def dtype(self) -> dtypes.Float64DType: ...
5071+
@property
5072+
@override
50195073
def real(self) -> Self: ...
50205074
@property
50215075
@override
@@ -5053,6 +5107,11 @@ class longdouble(_FloatMixin[L[12, 16]], floating[_n._64L]):
50535107
@type_check_only
50545108
def __nep50_rule6__(self, other: _JustInexact | _JustNumber, /) -> longdouble | clongdouble: ...
50555109

5110+
#
5111+
@property
5112+
@override
5113+
def dtype(self) -> dtypes.LongDoubleDType: ...
5114+
50565115
#
50575116
@overload
50585117
def item(self, /) -> Self: ...
@@ -5125,6 +5184,9 @@ class complex64(complexfloating[_n._32]):
51255184
#
51265185
@property
51275186
@override
5187+
def dtype(self) -> dtypes.Complex64DType: ...
5188+
@property
5189+
@override
51285190
def itemsize(self) -> L[8]: ...
51295191
@property
51305192
@override
@@ -5166,6 +5228,9 @@ class complex128(complexfloating[_n._64], complex): # type: ignore[misc]
51665228
#
51675229
@property
51685230
@override
5231+
def dtype(self) -> dtypes.Complex128DType: ...
5232+
@property
5233+
@override
51695234
def itemsize(self) -> L[16]: ...
51705235
@property
51715236
@override
@@ -5210,6 +5275,9 @@ class clongdouble(complexfloating[_n._64L]):
52105275
#
52115276
@property
52125277
@override
5278+
def dtype(self) -> dtypes.CLongDoubleDType: ...
5279+
@property
5280+
@override
52135281
def itemsize(self) -> L[24, 32]: ...
52145282
@property
52155283
@override
@@ -5276,6 +5344,10 @@ class object_(_RealMixin, generic[Any]):
52765344
if sys.version_info >= (3, 12):
52775345
def __release_buffer__(self, buffer: memoryview, /) -> None: ...
52785346

5347+
@property
5348+
@override
5349+
def dtype(self) -> dtypes.ObjectDType: ...
5350+
52795351
@final
52805352
class flexible(_RealMixin, generic[_FlexItemT_co], Generic[_FlexItemT_co]): # type: ignore[misc]
52815353
@abc.abstractmethod
@@ -5301,6 +5373,11 @@ class bytes_(character[bytes], bytes): # type: ignore[misc]
53015373
@overload
53025374
def __init__(self, o: object = b"", /) -> None: ...
53035375

5376+
#
5377+
@property
5378+
@override
5379+
def dtype(self) -> dtypes.BytesDType: ...
5380+
53045381
class str_(character[str], str): # type: ignore[misc]
53055382
@type_check_only
53065383
def __nep50__(self, below: str_ | object_, above: Never, /) -> str_: ...
@@ -5317,6 +5394,11 @@ class str_(character[str], str): # type: ignore[misc]
53175394
@overload
53185395
def __init__(self, value: bytes, /, encoding: str = "utf-8", errors: str = "strict") -> None: ...
53195396

5397+
#
5398+
@property
5399+
@override
5400+
def dtype(self) -> dtypes.StrDType: ...
5401+
53205402
class void(flexible[bytes | tuple[Any, ...]]): # type: ignore[misc] # pyright: ignore[reportGeneralTypeIssues]
53215403
@type_check_only
53225404
def __nep50__(self, below: object_, from_: Never, /) -> void: ...
@@ -5334,12 +5416,20 @@ class void(flexible[bytes | tuple[Any, ...]]): # type: ignore[misc] # pyright:
53345416
@override
53355417
def setfield(self, val: ArrayLike, dtype: DTypeLike, offset: int = ...) -> None: ...
53365418

5419+
#
5420+
@property
5421+
@override
5422+
def dtype(self) -> dtypes.VoidDType: ...
5423+
53375424
class datetime64(
53385425
_RealMixin,
53395426
_CmpOpMixin[datetime64, _ArrayLikeDT64_co],
53405427
generic[_DT64ItemT_co],
53415428
Generic[_DT64ItemT_co],
53425429
):
5430+
@property
5431+
@override
5432+
def dtype(self) -> dtypes.DateTime64DType: ...
53435433
@property
53445434
@override
53455435
def itemsize(self) -> L[8]: ...
@@ -5481,6 +5571,9 @@ class timedelta64(
54815571
#
54825572
@property
54835573
@override
5574+
def dtype(self) -> dtypes.TimeDelta64DType: ...
5575+
@property
5576+
@override
54845577
def itemsize(self) -> L[8]: ...
54855578
@property
54865579
@override

src/numpy-stubs/_core/_type_aliases.pyi

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,30 +7,30 @@ import numpy as np
77

88
@type_check_only
99
class _CNamesDict(TypedDict):
10-
BOOL: np.dtype[np.bool]
11-
HALF: np.dtype[np.half]
12-
FLOAT: np.dtype[np.single]
13-
DOUBLE: np.dtype[np.double]
14-
LONGDOUBLE: np.dtype[np.longdouble]
15-
CFLOAT: np.dtype[np.csingle]
16-
CDOUBLE: np.dtype[np.cdouble]
17-
CLONGDOUBLE: np.dtype[np.clongdouble]
18-
STRING: np.dtype[np.bytes_]
19-
UNICODE: np.dtype[np.str_]
20-
VOID: np.dtype[np.void]
21-
OBJECT: np.dtype[np.object_]
22-
DATETIME: np.dtype[np.datetime64]
23-
TIMEDELTA: np.dtype[np.timedelta64]
24-
BYTE: np.dtype[np.byte]
25-
UBYTE: np.dtype[np.ubyte]
26-
SHORT: np.dtype[np.short]
27-
USHORT: np.dtype[np.ushort]
28-
INT: np.dtype[np.intc]
29-
UINT: np.dtype[np.uintc]
30-
LONG: np.dtype[np.long]
31-
ULONG: np.dtype[np.ulong]
32-
LONGLONG: np.dtype[np.longlong]
33-
ULONGLONG: np.dtype[np.ulonglong]
10+
BOOL: np.dtypes.BoolDType
11+
BYTE: np.dtypes.ByteDType
12+
UBYTE: np.dtypes.UByteDType
13+
SHORT: np.dtypes.ShortDType
14+
USHORT: np.dtypes.UShortDType
15+
INT: np.dtypes.IntDType
16+
UINT: np.dtypes.UIntDType
17+
LONG: np.dtypes.LongDType
18+
ULONG: np.dtypes.ULongDType
19+
LONGLONG: np.dtypes.LongLongDType
20+
ULONGLONG: np.dtypes.ULongLongDType
21+
HALF: np.dtypes.Float16DType
22+
FLOAT: np.dtypes.Float32DType
23+
DOUBLE: np.dtypes.Float64DType
24+
LONGDOUBLE: np.dtypes.LongDoubleDType
25+
CFLOAT: np.dtypes.Complex64DType
26+
CDOUBLE: np.dtypes.Complex128DType
27+
CLONGDOUBLE: np.dtypes.CLongDoubleDType
28+
STRING: np.dtypes.BytesDType
29+
UNICODE: np.dtypes.StrDType
30+
VOID: np.dtypes.VoidDType
31+
OBJECT: np.dtypes.ObjectDType
32+
DATETIME: np.dtypes.DateTime64DType
33+
TIMEDELTA: np.dtypes.TimeDelta64DType
3434

3535
_AbstractTypeName: TypeAlias = L[
3636
"generic",

0 commit comments

Comments
 (0)