@@ -158,7 +158,6 @@ def jit(
158158 device : tp .Optional [jax .Device ] = None ,
159159 backend : tp .Optional [str ] = None ,
160160 inline : bool = False ,
161- abstracted_axes : tp .Optional [tp .Any ] = None ,
162161) -> tp .Callable [[tp .Callable [P , R ]], JitWrapped [P , R ]]: ...
163162@tp .overload
164163def jit (
@@ -174,7 +173,6 @@ def jit(
174173 device : tp .Optional [jax .Device ] = None ,
175174 backend : tp .Optional [str ] = None ,
176175 inline : bool = False ,
177- abstracted_axes : tp .Optional [tp .Any ] = None ,
178176) -> JitWrapped [P , R ]: ...
179177def jit (
180178 fun : tp .Callable [P , R ] | Missing = MISSING ,
@@ -189,7 +187,6 @@ def jit(
189187 device : tp .Optional [jax .Device ] = None ,
190188 backend : tp .Optional [str ] = None ,
191189 inline : bool = False ,
192- abstracted_axes : tp .Optional [tp .Any ] = None ,
193190) -> JitWrapped [P , R ] | tp .Callable [[tp .Callable [P , R ]], JitWrapped [P , R ]]:
194191 """
195192 Lifted version of ``jax.jit`` that can handle Modules / graph nodes as
@@ -342,7 +339,6 @@ def jit(
342339 device = device ,
343340 backend = backend ,
344341 inline = inline ,
345- abstracted_axes = abstracted_axes ,
346342 ) # type: ignore[return-value]
347343 # Detect bound nnx.Module methods and raise error.
348344 fun_unbound , _ , was_bound = _resolve_bound_callable (fun )
@@ -361,7 +357,6 @@ def jit(
361357 device = device ,
362358 backend = backend ,
363359 inline = inline ,
364- abstracted_axes = abstracted_axes ,
365360 )
366361
367362
@@ -387,7 +382,6 @@ def __init__(
387382 device : tp .Optional [jax .Device ] = None ,
388383 backend : tp .Optional [str ] = None ,
389384 inline : bool = False ,
390- abstracted_axes : tp .Optional [tp .Any ] = None ,
391385 ):
392386 functools .update_wrapper (self , fun )
393387 self .fun : tp .Callable [P , R ] = fun
@@ -432,7 +426,6 @@ def __init__(
432426 device = device ,
433427 backend = backend ,
434428 inline = inline ,
435- abstracted_axes = abstracted_axes ,
436429 )
437430 self .in_shardings = in_shardings
438431 self .out_shardings = out_shardings
0 commit comments