1- from typing import Any , Protocol , Self , TypeAlias , final , type_check_only
1+ from typing import Any , Generic , Protocol , Self , TypeAlias , final , type_check_only
22from typing_extensions import TypeAliasType , TypeVar
33
44from ._shape import (
55 Shape ,
66 Shape as Shape0ToN ,
77 Shape0 ,
88 Shape1 ,
9- Shape1_ as Shape1ToN ,
9+ Shape1N as Shape1ToN ,
1010 Shape2 ,
11- Shape2_ as Shape2ToN ,
11+ Shape2N as Shape2ToN ,
1212 Shape3 ,
13- Shape3_ as Shape3ToN ,
13+ Shape3N as Shape3ToN ,
1414 Shape4 ,
15- Shape4_ as Shape4ToN ,
15+ Shape4N as Shape4ToN ,
1616)
1717
1818__all__ = [
19- "Broadcastable" ,
19+ "Broadcasts" ,
20+ "BroadcastsTo" ,
21+ "Rank" ,
2022 "Rank0" ,
21- "Rank0ToN " ,
23+ "Rank0N " ,
2224 "Rank1" ,
23- "Rank1ToN " ,
25+ "Rank1N " ,
2426 "Rank2" ,
25- "Rank2ToN " ,
27+ "Rank2N " ,
2628 "Rank3" ,
27- "Rank3ToN " ,
29+ "Rank3N " ,
2830 "Rank4" ,
29- "Rank4ToN " ,
31+ "Rank4N " ,
3032]
3133
3234###
3335
34- _ToT = TypeVar ("_ToT" , bound = Shape )
35- _ToT_contra = TypeVar ("_ToT_contra" , bound = Shape , contravariant = True )
36- _FromT = TypeVar ("_FromT" , bound = Shape )
37- _FromT_contra = TypeVar ("_FromT_contra" , bound = Shape , default = Any , contravariant = True )
38- _RankT = TypeVar ("_RankT" , bound = _HasShape [Any ], default = Any )
39- _RankT_co = TypeVar ("_RankT_co" , default = Any , covariant = True )
40- _ShapeT_contra = TypeVar ("_ShapeT_contra" , contravariant = True )
41- _ShapeT_co = TypeVar ("_ShapeT_co" , covariant = True , default = _ShapeT_contra )
42-
43- ###
44-
4536_Shape0To0 : TypeAlias = Shape0
4637_Shape0To1 : TypeAlias = _Shape0To0 | Shape1
4738_Shape0To2 : TypeAlias = _Shape0To1 | Shape2
@@ -50,74 +41,134 @@ _Shape0To4: TypeAlias = _Shape0To3 | Shape4
5041
5142###
5243
44+ _ToT = TypeVar ("_ToT" , bound = Shape )
45+ _FromT = TypeVar ("_FromT" , bound = Shape )
46+ _RankT = TypeVar ("_RankT" , bound = Shape , default = Any )
47+
48+ _BroadcastableShape = TypeAliasType (
49+ "_BroadcastableShape" ,
50+ _FromT | _CanBroadcastFrom [_FromT , _RankT ],
51+ type_params = (_FromT , _RankT ),
52+ )
53+
54+ BroadcastsTo = TypeAliasType (
55+ "BroadcastsTo" ,
56+ _HasRank [_CanBroadcastTo [_ToT , _RankT ]],
57+ type_params = (_ToT , _RankT ),
58+ )
59+ Broadcasts = TypeAliasType (
60+ "Broadcasts" ,
61+ _HasRank [_BroadcastableShape [_FromT , _RankT ]],
62+ type_params = (_FromT , _RankT ),
63+ )
64+
65+ ###
66+
67+ _ShapeT_co = TypeVar (
68+ "_ShapeT_co" ,
69+ bound = Shape | _HasOwnShape | _CanBroadcastFrom | _CanBroadcastTo ,
70+ covariant = True ,
71+ )
72+
73+ @type_check_only
74+ class _HasShape (Protocol [_ShapeT_co ]):
75+ @property
76+ def shape (self , / ) -> _ShapeT_co : ...
77+
78+ _ShapeT = TypeVar ("_ShapeT" , bound = Shape )
79+
80+ @final
5381@type_check_only
54- class _CanBroadcast (Protocol [_ToT_contra , _FromT_contra , _RankT_co ]):
55- def __broadcast__ (self , to : _ToT_contra , from_ : _FromT_contra , / ) -> _RankT_co : ...
82+ class _HasRank (Protocol [_ShapeT_co ]):
83+ @property
84+ def shape (self : _HasShape [_ShapeT ], / ) -> _ShapeT : ...
85+
86+ _FromT_contra = TypeVar ("_FromT_contra" , default = Any , contravariant = True )
87+ _ToT_contra = TypeVar ("_ToT_contra" , bound = Shape , default = Any , contravariant = True )
88+ _RankT_co = TypeVar ("_RankT_co" , default = Any , covariant = True )
89+
90+ @final
91+ @type_check_only
92+ class _CanBroadcastFrom (Protocol [_FromT_contra , _RankT_co ]):
93+ def __broadcast_from__ (self , from_ : _FromT_contra , / ) -> _RankT_co : ...
94+
95+ @final
96+ @type_check_only
97+ class _CanBroadcastTo (Protocol [_ToT_contra , _RankT_co ]):
98+ def __broadcast_to__ (self , to : _ToT_contra , / ) -> _RankT_co : ...
5699
57100# This double shape-type parameter is a sneaky way to annotate a doubly-bound nominal type range,
58101# e.g. `_HasShape[Shape2ToN, Shape0ToN]` accepts `Shape2ToN`, `Shape1ToN`, and `Shape0ToN`, but
59102# rejects `Shape3ToN` and `Shape1`. Besides brevity, it also works around several mypy bugs that
60103# are related to "unions vs joins".
104+
105+ _OwnShapeT_contra = TypeVar ("_OwnShapeT_contra" , bound = Shape , default = Any , contravariant = True )
106+ _OwnShapeT_co = TypeVar ("_OwnShapeT_co" , bound = Shape , default = _OwnShapeT_contra , covariant = True )
107+
108+ @final
61109@type_check_only
62- class _HasShape (Protocol [_ShapeT_contra , _ShapeT_co ]):
63- def __shape__ (self , shape : _ShapeT_contra , / ) -> _ShapeT_co : ...
110+ class _HasOwnShape (Protocol [_OwnShapeT_contra , _OwnShapeT_co ]):
111+ def __own_shape__ (self , shape : _OwnShapeT_contra , / ) -> _OwnShapeT_co : ...
64112
65113###
66114# TODO(jorenham): embed the array-like types, e.g. `Sequence[Sequence[T]]`
67115
68- @final
116+ _OwnShapeT = TypeVar ("_OwnShapeT" , bound = tuple [Any , ...], default = Any )
117+
69118@type_check_only
70- class Rank0 (tuple [()], _HasShape [Shape0 , Shape0 ]):
71- def __broadcast__ (self , to : Shape0ToN , from_ : _HasShape [Shape0ToN , Shape0ToN ], / ) -> Self : ...
119+ class _BaseRank (Generic [_FromT_contra , _ToT_contra , _OwnShapeT ]):
120+ def __broadcast_from__ (self , from_ : _FromT_contra , / ) -> Self : ...
121+ def __broadcast_to__ (self , to : _ToT_contra , / ) -> Self : ...
122+ def __own_shape__ (self , shape : _OwnShapeT , / ) -> _OwnShapeT : ...
72123
73124@type_check_only
74- class Rank1 (tuple [int ], _HasShape [Shape1 , Shape1 ]):
75- def __broadcast__ (self , to : Shape1ToN , from_ : _Shape0To1 | _HasShape [Shape1ToN , Shape0ToN ], / ) -> Self : ...
125+ class _BaseRankM (
126+ _BaseRank [_FromT_contra | _HasOwnShape [_ToT_contra , Shape ], _ToT_contra , _OwnShapeT ],
127+ Generic [_FromT_contra , _ToT_contra , _OwnShapeT ],
128+ ): ...
76129
77130@final
78131@type_check_only
79- class Rank2 (tuple [int , int ], _HasShape [Shape2 , Shape2 ]):
80- def __broadcast__ (self , to : Shape2ToN , from_ : _Shape0To2 | _HasShape [Shape2ToN , Shape0ToN ], / ) -> Self : ...
132+ class Rank0 (_BaseRankM [_Shape0To0 , Shape0ToN , Shape0 ], tuple [()]): ...
81133
82134@final
83135@type_check_only
84- class Rank3 (tuple [int , int , int ], _HasShape [Shape3 , Shape3 ]):
85- def __broadcast__ (self , to : Shape3ToN , from_ : _Shape0To3 | _HasShape [Shape3ToN , Shape0ToN ], / ) -> Self : ...
136+ class Rank1 (_BaseRankM [_Shape0To1 , Shape1ToN , Shape1 ], tuple [int ]): ...
86137
87138@final
88139@type_check_only
89- class Rank4 (tuple [int , int , int , int ], _HasShape [Shape4 , Shape4 ]):
90- def __broadcast__ (self , to : Shape4ToN , from_ : _Shape0To4 | _HasShape [Shape4ToN , Shape0ToN ], / ) -> Self : ...
140+ class Rank2 (_BaseRankM [_Shape0To2 , Shape2ToN , Shape2 ], tuple [int , int ]): ...
91141
92- ###
93- # These emulate `AnyOf`, rather than a `Union`.
142+ @final
143+ @type_check_only
144+ class Rank3 (_BaseRankM [_Shape0To3 , Shape3ToN , Shape3 ], tuple [int , int , int ]): ...
94145
95146@final
96147@type_check_only
97- class Rank0ToN (tuple [int , ...], _HasShape [Shape0ToN , Shape0ToN ]):
98- def __broadcast__ (self , to : Shape0ToN , from_ : Shape0ToN , / ) -> Self : ...
148+ class Rank4 (_BaseRankM [_Shape0To4 , Shape4ToN , Shape4 ], tuple [int , int , int , int ]): ...
149+
150+ # this emulates `AnyOf`, rather than a `Union`.
151+ @type_check_only
152+ class _BaseRankMToN (_BaseRank [Shape0ToN , _OwnShapeT , _OwnShapeT ], Generic [_OwnShapeT ]): ...
99153
100154@final
101155@type_check_only
102- class Rank1ToN (tuple [int , * tuple [int , ...]], _HasShape [Shape1ToN , Shape1ToN ]):
103- def __broadcast__ (self , to : Shape1ToN , from_ : Shape0ToN , / ) -> Self : ...
156+ class Rank (_BaseRankMToN [Shape0ToN ], tuple [int , ...]): ...
104157
105158@final
106159@type_check_only
107- class Rank2ToN (tuple [int , int , * tuple [int , ...]], _HasShape [Shape2ToN , Shape2ToN ]):
108- def __broadcast__ (self , to : Shape2ToN , from_ : Shape0ToN , / ) -> Self : ...
160+ class Rank1N (_BaseRankMToN [Shape1ToN ], tuple [int , * tuple [int , ...]]): ...
109161
110162@final
111163@type_check_only
112- class Rank3ToN (tuple [int , int , int , * tuple [int , ...]], _HasShape [Shape3ToN , Shape3ToN ]):
113- def __broadcast__ (self , to : Shape3ToN , from_ : Shape0ToN , / ) -> Self : ...
164+ class Rank2N (_BaseRankMToN [Shape2ToN ], tuple [int , int , * tuple [int , ...]]): ...
114165
115166@final
116167@type_check_only
117- class Rank4ToN (tuple [int , int , int , int , * tuple [int , ...]], _HasShape [Shape4ToN , Shape4ToN ]):
118- def __broadcast__ (self , to : Shape4ToN , from_ : Shape0ToN , / ) -> Self : ...
168+ class Rank3N (_BaseRankMToN [Shape3ToN ], tuple [int , int , int , * tuple [int , ...]]): ...
119169
120- ###
170+ @final
171+ @type_check_only
172+ class Rank4N (_BaseRankMToN [Shape4ToN ], tuple [int , int , int , int , * tuple [int , ...]]): ...
121173
122- Broadcastable = TypeAliasType ("Broadcastable" , _CanBroadcast [_ToT , Any , _RankT ], type_params = (_ToT , _RankT ))
123- Broadcaster = TypeAliasType ("Broadcaster" , _FromT | _CanBroadcast [Any , _FromT , _RankT ], type_params = (_FromT , _RankT ))
174+ Rank0N : TypeAlias = Rank
0 commit comments