Skip to content

Commit b3e9b44

Browse files
Cristian GarciaFlax Authors
authored andcommitted
remove abstracted_axes from nnx.jit
PiperOrigin-RevId: 842321273
1 parent 2204c82 commit b3e9b44

1 file changed

Lines changed: 0 additions & 7 deletions

File tree

flax/nnx/transforms/compilation.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,6 @@ def jit(
158158
device: tp.Optional[jax.Device] = None,
159159
backend: tp.Optional[str] = None,
160160
inline: bool = False,
161-
abstracted_axes: tp.Optional[tp.Any] = None,
162161
) -> tp.Callable[[tp.Callable[P, R]], JitWrapped[P, R]]: ...
163162
@tp.overload
164163
def jit(
@@ -174,7 +173,6 @@ def jit(
174173
device: tp.Optional[jax.Device] = None,
175174
backend: tp.Optional[str] = None,
176175
inline: bool = False,
177-
abstracted_axes: tp.Optional[tp.Any] = None,
178176
) -> JitWrapped[P, R]: ...
179177
def jit(
180178
fun: tp.Callable[P, R] | Missing = MISSING,
@@ -189,7 +187,6 @@ def jit(
189187
device: tp.Optional[jax.Device] = None,
190188
backend: tp.Optional[str] = None,
191189
inline: bool = False,
192-
abstracted_axes: tp.Optional[tp.Any] = None,
193190
) -> JitWrapped[P, R] | tp.Callable[[tp.Callable[P, R]], JitWrapped[P, R]]:
194191
"""
195192
Lifted version of ``jax.jit`` that can handle Modules / graph nodes as
@@ -342,7 +339,6 @@ def jit(
342339
device=device,
343340
backend=backend,
344341
inline=inline,
345-
abstracted_axes=abstracted_axes,
346342
) # type: ignore[return-value]
347343
# Detect bound nnx.Module methods and raise error.
348344
fun_unbound, _, was_bound = _resolve_bound_callable(fun)
@@ -361,7 +357,6 @@ def jit(
361357
device=device,
362358
backend=backend,
363359
inline=inline,
364-
abstracted_axes=abstracted_axes,
365360
)
366361

367362

@@ -387,7 +382,6 @@ def __init__(
387382
device: tp.Optional[jax.Device] = None,
388383
backend: tp.Optional[str] = None,
389384
inline: bool = False,
390-
abstracted_axes: tp.Optional[tp.Any] = None,
391385
):
392386
functools.update_wrapper(self, fun)
393387
self.fun: tp.Callable[P, R] = fun
@@ -432,7 +426,6 @@ def __init__(
432426
device=device,
433427
backend=backend,
434428
inline=inline,
435-
abstracted_axes=abstracted_axes,
436429
)
437430
self.in_shardings = in_shardings
438431
self.out_shardings = out_shardings

0 commit comments

Comments
 (0)