2323 TypeVar ,
2424 cast ,
2525 final ,
26+ get_args ,
2627 overload ,
2728)
2829from typing_extensions import override
6263 "xor" ,
6364]
6465_UnOpName : TypeAlias = Literal ["abs" , "neg" , "pos" , "invert" ]
65- _OpName : TypeAlias = Literal [_BinOpName , _UnOpName ]
66+ _OpName : TypeAlias = Literal [_UnOpName , _BinOpName ]
6667
6768###
6869
69- ROOT_DIR : Final = Path (__file__ ).parent .parent
70- TARGET_DIR : Final = ROOT_DIR / "src" / "numpy-stubs" / "@test" / "generated"
70+ DIR_ROOT : Final = Path (__file__ ).parent .parent
71+ DIR_SRC : Final = DIR_ROOT / "src"
72+ DIRS_TARGET : Final = {
73+ dir_package .stem : dir_package / "@test" / "generated"
74+ for dir_package in DIR_SRC .iterdir ()
75+ if dir_package .is_dir ()
76+ }
7177
7278TAB : Final = " " * 4
7379BR : Final = "\n "
@@ -366,6 +372,7 @@ def _strip_preamble(source: str) -> tuple[str | None, str]:
366372
367373
368374class TestGen (abc .ABC ):
375+ package : ClassVar [str ]
369376 stdlib_imports : ClassVar [tuple [str , ...]] = ("from typing import assert_type" ,)
370377 numpy_imports : ClassVar [tuple [str , ...]] = (f"import numpy as { NP } " ,)
371378
@@ -383,7 +390,7 @@ def __init__(self) -> None:
383390 @property
384391 def path (self ) -> Path :
385392 assert self .testname
386- return TARGET_DIR / f"{ self .testname } .pyi"
393+ return DIRS_TARGET [ self . package ] / f"{ self .testname } .pyi"
387394
388395 def get_names (self ) -> Iterable [tuple [str , str ]]:
389396 return ()
@@ -411,7 +418,7 @@ def _generate_section(self, /, *lines: str) -> Generator[str]:
411418
412419 def _generate_preamble (self ) -> Generator [str ]:
413420 timestamp = f"{ np .datetime64 ('now' )} Z"
414- here = Path (__file__ ).relative_to (ROOT_DIR )
421+ here = Path (__file__ ).relative_to (DIR_ROOT )
415422
416423 yield f"# { PREAMBLE_PREFIX } { timestamp } with { here } "
417424
@@ -492,7 +499,7 @@ def regenerate(self, /, *, always: bool = False) -> Iterator[str]:
492499 head_new , body_new = _strip_preamble (src_new )
493500 assert head_new , src_new
494501
495- path_new = str (self .path .relative_to (ROOT_DIR ))
502+ path_new = str (self .path .relative_to (DIR_ROOT ))
496503 date_new = head_new .split (" " , 1 )[0 ]
497504
498505 if src_old := self ._read ():
@@ -516,13 +523,14 @@ def regenerate(self, /, *, always: bool = False) -> Iterator[str]:
516523 tofile = path_new ,
517524 fromfiledate = date_old ,
518525 tofiledate = date_new if write else date_old ,
519- n = 0 ,
526+ n = 1 ,
520527 lineterm = BR ,
521528 )
522529
523530
524531@final
525532class EMath (TestGen ):
533+ package = "numpy-stubs"
526534 testname = "emath"
527535
528536 VALUES : Final [dict [str , list [Any ]]] = {
@@ -735,6 +743,7 @@ def get_testcases(self) -> Iterable[str | None]:
735743
736744@final
737745class LiteralBoolOps (TestGen ):
746+ package = "numpy-stubs"
738747 testname = "literal_bool_ops"
739748
740749 UNOPS : ClassVar = {
@@ -887,6 +896,7 @@ def get_testcases(self) -> Iterable[str | None]:
887896
888897@final
889898class ScalarOps (TestGen ):
899+ package = "numpy-stubs"
890900 testname = "scalar_ops_{}"
891901
892902 OPS_ARITHMETIC : ClassVar [dict [str , _BinOp ]] = {
@@ -1144,6 +1154,7 @@ def get_testcases(self) -> Iterable[str | None]:
11441154
11451155
11461156class NDArrayOps (TestGen ):
1157+ package = "numpy-stubs"
11471158 testname = "ndarray_{}"
11481159 numpy_imports_extra : tuple [str , ...] = ("import _numtype as _nt" ,)
11491160
@@ -1527,36 +1538,52 @@ def get_testcases(self) -> Iterable[str | None]:
15271538TESTGENS : Final [Sequence [TestGen ]] = [
15281539 EMath (binary = False ),
15291540 LiteralBoolOps (),
1530- ScalarOps ("arithmetic" ),
1531- ScalarOps ("modular" ),
1532- ScalarOps ("bitwise" ),
1533- ScalarOps ("comparison" ),
1534- NDArrayOps ("pos" ),
1535- NDArrayOps ("neg" ),
1536- NDArrayOps ("abs" ),
1537- NDArrayOps ("invert" ),
1538- NDArrayOps ("add" ),
1539- NDArrayOps ("sub" ),
1540- NDArrayOps ("mul" ),
1541- NDArrayOps ("matmul" ),
1542- NDArrayOps ("pow" ),
1543- NDArrayOps ("truediv" ),
1544- NDArrayOps ("floordiv" ),
1545- NDArrayOps ("mod" ),
1546- NDArrayOps ("divmod" ),
1547- NDArrayOps ("lshift" ),
1548- NDArrayOps ("rshift" ),
1549- NDArrayOps ("and" ),
1550- NDArrayOps ("xor" ),
1551- NDArrayOps ("or" ),
1541+ * (ScalarOps (op_kind ) for op_kind in get_args (_BinOpKind )),
1542+ * (NDArrayOps (op_name ) for op_name in get_args (_OpName )),
15521543]
15531544
15541545
15551546@np .errstate (all = "ignore" )
15561547def main () -> None :
1557- """(Re)generate the `src/numpy-stubs/@test/generated/{}.pyi` type-tests."""
1548+ """(Re)generate the `src/*/@test/generated/{}.pyi` type-tests."""
1549+ cwd = Path .cwd ()
1550+ paths : dict [str , dict [Path , bool ]] = {}
1551+
15581552 for testgen in TESTGENS :
1559- sys .stdout .writelines (testgen .regenerate ())
1553+ path = testgen .path
1554+ diff = testgen .regenerate ()
1555+ diff_out , diff_check = itertools .tee (diff , 2 )
1556+ sys .stderr .writelines (diff_out )
1557+ sys .stderr .write ("\n " )
1558+ sys .stderr .flush ()
1559+
1560+ diff_count = sum (1 for _ in diff_check )
1561+ if not diff_count :
1562+ sys .stdout .write (f"skipped ./{ path .relative_to (cwd )} \n " )
1563+ sys .stdout .flush ()
1564+
1565+ package_paths = paths .setdefault (testgen .package , {})
1566+ assert path not in package_paths , path
1567+ package_paths [path ] = bool (diff_count )
1568+
1569+ orphans : list [Path ] = []
1570+ for package , testdir in DIRS_TARGET .items ():
1571+ if not testdir .exists ():
1572+ continue
1573+ assert testdir .is_dir ()
1574+
1575+ known = paths .get (package , {})
1576+ for path in testdir .rglob ("*.pyi" ):
1577+ assert path .is_file ()
1578+ if path not in known :
1579+ orphans .append (path )
1580+
1581+ for orphan in orphans :
1582+ assert orphan .is_file ()
1583+ orphan .unlink ()
1584+
1585+ sys .stderr .write (f"removed ./{ orphan .relative_to (cwd )} \n " )
1586+ sys .stderr .flush ()
15601587
15611588
15621589if __name__ == "__main__" :
0 commit comments