@@ -466,6 +466,101 @@ def checkify_wrapper(*args, **kwargs):
466466 return checkify_wrapper # type: ignore
467467
468468
469+ @dataclasses .dataclass (eq = False )
470+ class SimpleMakeJaxprFn :
471+ f : tp .Callable [..., tp .Any ]
472+ graph : bool
473+
474+ def __post_init__ (self ):
475+ functools .update_wrapper (self , self .f , updated = ())
476+
477+ @extract .treemap_copy_args
478+ def __call__ (self , * args , ** kwargs ):
479+ if self .graph :
480+ args , kwargs = extract .from_tree2 ((args , kwargs ))
481+ out = self .f (* args , ** kwargs )
482+ if self .graph :
483+ out = extract .to_tree2 (out )
484+ extract .check_no_aliases ('make_jaxpr' , args = args , kwargs = kwargs , out = out )
485+ return out
486+
487+
488+ @tp .overload
489+ def make_jaxpr (
490+ f : tp .Callable [..., A ],
491+ * ,
492+ graph : bool | None = None ,
493+ graph_updates : bool | None = None ,
494+ static_argnums : int | tp .Sequence [int ] = (),
495+ ) -> tp .Callable [..., tp .Any ]: ...
496+
497+ @tp .overload
498+ def make_jaxpr (
499+ * ,
500+ graph : bool | None = None ,
501+ graph_updates : bool | None = None ,
502+ static_argnums : int | tp .Sequence [int ] = (),
503+ ) -> tp .Callable [[F ], tp .Callable [..., tp .Any ]]: ...
504+
505+ def make_jaxpr (
506+ f : tp .Callable [..., A ] | Missing = MISSING ,
507+ * ,
508+ graph : bool | None = None ,
509+ graph_updates : bool | None = None ,
510+ static_argnums : int | tp .Sequence [int ] = (),
511+ ) -> tp .Callable [..., tp .Any ] | tp .Callable [[F ], tp .Callable [..., tp .Any ]]:
512+ """A "lifted" version of `jax.make_jaxpr <https://jax.readthedocs.io/en/latest/jaxpr.html>`_
513+ that can handle `flax.nnx.Module <https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module>`_
514+ / graph nodes as arguments.
515+
516+ Args:
517+ f: the function to be transformed into a Jaxpr.
518+ graph: If ``True`` (default), uses graph-mode which supports the full
519+ NNX feature set including shared references and reference semantics.
520+ If ``False``, uses tree-mode which treats Modules as regular JAX
521+ pytrees, avoiding the overhead of the graph protocol.
522+ graph_updates: If ``True``, propagates updates on graph structure
523+ that happen inside the transform to the input graphs, has no
524+ effect when ``graph=False``. ``nnx.make_jaxpr`` raises an error
525+ if ``graph_updates=True``.
526+ static_argnums: Optional, int or sequence of ints. Specifies which
527+ positional argument(s) to treat as static (compile-time constant).
528+ """
529+ if isinstance (f , Missing ):
530+ return functools .partial (
531+ make_jaxpr ,
532+ graph = graph ,
533+ graph_updates = graph_updates ,
534+ static_argnums = static_argnums ,
535+ )
536+
537+ if graph_updates is None :
538+ graph_updates = graphlib .set_graph_updates .current_value ()
539+ if graph_updates :
540+ raise ValueError ('nnx.make_jaxpr does not support graph_updates=True.' )
541+
542+ f_call , _ , was_bound = _resolve_bound_callable (f )
543+ if was_bound :
544+ _raise_bound_method_error ('make_jaxpr' )
545+ if graph is None :
546+ graph = graphlib .set_graph_mode .current_value ()
547+
548+ jaxpr_maker = jax .make_jaxpr (
549+ SimpleMakeJaxprFn (f_call , graph = graph ),
550+ static_argnums = static_argnums ,
551+ )
552+
553+ @functools .wraps (f )
554+ def jaxpr_wrapper (* args , ** kwargs ):
555+ if graph :
556+ args , kwargs = extract .to_tree2 ((args , kwargs ))
557+ extract .check_no_aliases ('make_jaxpr' , args = args , kwargs = kwargs )
558+ jaxpr = jaxpr_maker (* args , ** kwargs )
559+ return jaxpr
560+
561+ return jaxpr_wrapper
562+
563+
469564@dataclasses .dataclass (eq = False )
470565class SimpleCondFn :
471566 f : tp .Callable [..., tp .Any ]
0 commit comments