Skip to content

Commit d539ce7

Browse files
author
Flax Team
committed
Add nnx.make_jaxpr.
This change introduces `nnx.make_jaxpr`, a version of `jax.make_jaxpr`.`graph_updates=True` is not supported and will raise an error. PiperOrigin-RevId: 889901203
1 parent 35d911a commit d539ce7

3 files changed

Lines changed: 125 additions & 0 deletions

File tree

flax/nnx/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@
195195
from .transforms.transforms import cond as cond
196196
from .transforms.transforms import switch as switch
197197
from .transforms.transforms import checkify as checkify
198+
from .transforms.transforms import make_jaxpr as make_jaxpr
198199
from .transforms.iteration import while_loop as while_loop
199200
from .transforms.iteration import fori_loop as fori_loop
200201
from .transforms.iteration import StateAxes as StateAxes

flax/nnx/transforms/transforms.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,101 @@ def checkify_wrapper(*args, **kwargs):
466466
return checkify_wrapper # type: ignore
467467

468468

469+
@dataclasses.dataclass(eq=False)
470+
class SimpleMakeJaxprFn:
471+
f: tp.Callable[..., tp.Any]
472+
graph: bool
473+
474+
def __post_init__(self):
475+
functools.update_wrapper(self, self.f, updated=())
476+
477+
@extract.treemap_copy_args
478+
def __call__(self, *args, **kwargs):
479+
if self.graph:
480+
args, kwargs = extract.from_tree2((args, kwargs))
481+
out = self.f(*args, **kwargs)
482+
if self.graph:
483+
out = extract.to_tree2(out)
484+
extract.check_no_aliases('make_jaxpr', args=args, kwargs=kwargs, out=out)
485+
return out
486+
487+
488+
@tp.overload
489+
def make_jaxpr(
490+
f: tp.Callable[..., A],
491+
*,
492+
graph: bool | None = None,
493+
graph_updates: bool | None = None,
494+
static_argnums: int | tp.Sequence[int] = (),
495+
) -> tp.Callable[..., tp.Any]: ...
496+
497+
@tp.overload
498+
def make_jaxpr(
499+
*,
500+
graph: bool | None = None,
501+
graph_updates: bool | None = None,
502+
static_argnums: int | tp.Sequence[int] = (),
503+
) -> tp.Callable[[F], tp.Callable[..., tp.Any]]: ...
504+
505+
def make_jaxpr(
506+
f: tp.Callable[..., A] | Missing = MISSING,
507+
*,
508+
graph: bool | None = None,
509+
graph_updates: bool | None = None,
510+
static_argnums: int | tp.Sequence[int] = (),
511+
) -> tp.Callable[..., tp.Any] | tp.Callable[[F], tp.Callable[..., tp.Any]]:
512+
"""A "lifted" version of `jax.make_jaxpr <https://jax.readthedocs.io/en/latest/jaxpr.html>`_
513+
that can handle `flax.nnx.Module <https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module>`_
514+
/ graph nodes as arguments.
515+
516+
Args:
517+
f: the function to be transformed into a Jaxpr.
518+
graph: If ``True`` (default), uses graph-mode which supports the full
519+
NNX feature set including shared references and reference semantics.
520+
If ``False``, uses tree-mode which treats Modules as regular JAX
521+
pytrees, avoiding the overhead of the graph protocol.
522+
graph_updates: If ``True``, propagates updates on graph structure
523+
that happen inside the transform to the input graphs, has no
524+
effect when ``graph=False``. ``nnx.make_jaxpr`` raises an error
525+
if ``graph_updates=True``.
526+
static_argnums: Optional, int or sequence of ints. Specifies which
527+
positional argument(s) to treat as static (compile-time constant).
528+
"""
529+
if isinstance(f, Missing):
530+
return functools.partial(
531+
make_jaxpr,
532+
graph=graph,
533+
graph_updates=graph_updates,
534+
static_argnums=static_argnums,
535+
)
536+
537+
if graph_updates is None:
538+
graph_updates = graphlib.set_graph_updates.current_value()
539+
if graph_updates:
540+
raise ValueError('nnx.make_jaxpr does not support graph_updates=True.')
541+
542+
f_call, _, was_bound = _resolve_bound_callable(f)
543+
if was_bound:
544+
_raise_bound_method_error('make_jaxpr')
545+
if graph is None:
546+
graph = graphlib.set_graph_mode.current_value()
547+
548+
jaxpr_maker = jax.make_jaxpr(
549+
SimpleMakeJaxprFn(f_call, graph=graph),
550+
static_argnums=static_argnums,
551+
)
552+
553+
@functools.wraps(f)
554+
def jaxpr_wrapper(*args, **kwargs):
555+
if graph:
556+
args, kwargs = extract.to_tree2((args, kwargs))
557+
extract.check_no_aliases('make_jaxpr', args=args, kwargs=kwargs)
558+
jaxpr = jaxpr_maker(*args, **kwargs)
559+
return jaxpr
560+
561+
return jaxpr_wrapper
562+
563+
469564
@dataclasses.dataclass(eq=False)
470565
class SimpleCondFn:
471566
f: tp.Callable[..., tp.Any]

tests/nnx/transforms_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5224,6 +5224,35 @@ def f(c):
52245224
np.testing.assert_allclose(out, 1)
52255225

52265226

5227+
class TestMakeJaxpr(parameterized.TestCase):
5228+
5229+
def test_make_jaxpr_graph_updates_error(self):
5230+
m = nnx.Dict(a=nnx.Param(jnp.array(1)))
5231+
5232+
def f(m):
5233+
return m['a'][...]
5234+
5235+
with self.assertRaisesRegex(
5236+
ValueError, 'nnx.make_jaxpr does not support graph_updates=True.'
5237+
):
5238+
nnx.make_jaxpr(f, graph=True, graph_updates=True)(m)
5239+
5240+
@parameterized.parameters(True, False)
5241+
def test_make_jaxpr_with_variable_update(self, graph):
5242+
class Counter(nnx.Module):
5243+
def __init__(self):
5244+
self.count = nnx.Variable(jnp.array(0))
5245+
5246+
def __call__(self):
5247+
self.count[...] += 1
5248+
return self.count[...]
5249+
5250+
m = Counter()
5251+
jaxpr = nnx.make_jaxpr(lambda m: m(), graph=graph, graph_updates=False)(m)
5252+
self.assertIsNotNone(jaxpr)
5253+
self.assertEqual(m.count[...], 0)
5254+
5255+
52275256
class TestBoundMethodTransforms(parameterized.TestCase):
52285257
def test_remat_with_bound_method_raises(self):
52295258
class M(nnx.Module):

0 commit comments

Comments
 (0)