Skip to content

Commit b8aa99d

Browse files
Cristian Garcia and Flax Authors
authored and committed
remove fingerprint
PiperOrigin-RevId: 842343114
1 parent b3e9b44 commit b8aa99d

2 files changed

Lines changed: 29 additions & 243 deletions

File tree

flax/nnx/graph.py

Lines changed: 8 additions & 205 deletions
Original file line numberDiff line numberDiff line change
@@ -865,207 +865,6 @@ def make_mutable_arraydef(value: variablelib.Ref):
865865
return
866866

867867

868-
@dataclasses.dataclass(slots=True)
869-
class FingerprintContext:
870-
next_index: int
871-
872-
873-
# TODO(cgarciae): the actual fingerprint object is not being used,
874-
# only the traversal process is still relevant
875-
def fingerprint(
876-
node,
877-
/,
878-
*,
879-
ref_index: RefMap | None = None,
880-
new_ref_index: RefMap | None = None,
881-
) -> list[tp.Hashable]:
882-
""" """
883-
if ref_index is None:
884-
ref_index = RefMap()
885-
886-
if new_ref_index is None:
887-
new_ref_index = RefMap()
888-
node_impl = get_node_impl(node)
889-
if node_impl is None:
890-
raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')
891-
ctx = FingerprintContext(len(ref_index) + len(new_ref_index))
892-
fp: list[tp.Hashable] = []
893-
_graph_fingerprint(ctx, fp.append, node, node_impl, ref_index, new_ref_index)
894-
return fp
895-
896-
897-
def _graph_fingerprint(
898-
ctx: FingerprintContext,
899-
append_fn: tp.Callable[[tp.Any], None],
900-
node,
901-
node_impl: NodeImpl[Node, Leaf, AuxData],
902-
ref_index: RefMap,
903-
new_ref_index: RefMap,
904-
):
905-
is_pytree_node_ = type(node_impl) is PytreeNodeImpl
906-
is_graph_node_ = type(node_impl) is GraphNodeImpl
907-
908-
append_fn(type(node))
909-
910-
if is_graph_node_:
911-
append_fn(id(node))
912-
if node in ref_index:
913-
append_fn(ref_index[node])
914-
return
915-
elif node in new_ref_index:
916-
append_fn(new_ref_index[node])
917-
return
918-
index = new_ref_index[node] = ctx.next_index
919-
ctx.next_index += 1
920-
else:
921-
index = -1
922-
923-
values, metadata = node_impl.flatten(node)
924-
925-
append_fn(index)
926-
append_fn(metadata)
927-
928-
for key, value in values:
929-
value_node_impl = get_node_impl(value)
930-
append_fn(key)
931-
if value_node_impl is not None:
932-
_graph_fingerprint(
933-
ctx,
934-
append_fn,
935-
value,
936-
value_node_impl,
937-
ref_index,
938-
new_ref_index,
939-
)
940-
elif isinstance(value, Variable):
941-
append_fn(id(value))
942-
append_fn(type(value))
943-
if value in ref_index:
944-
append_fn(ref_index[value])
945-
elif value in new_ref_index:
946-
append_fn(new_ref_index[value])
947-
else:
948-
variable_index = new_ref_index[value] = ctx.next_index
949-
ctx.next_index += 1
950-
append_fn(variable_index)
951-
for key_value in value.get_metadata().items():
952-
append_fn(key_value)
953-
elif not isinstance(value, (jax.Array, np.ndarray)):
954-
append_fn(value)
955-
956-
957-
def check_fingerprint(
958-
node,
959-
fp: list[tp.Hashable],
960-
/,
961-
*,
962-
ref_index: RefMap | None = None,
963-
new_ref_index: RefMap | None = None,
964-
) -> bool:
965-
""" """
966-
if ref_index is None:
967-
ref_index = RefMap()
968-
969-
if new_ref_index is None:
970-
new_ref_index = RefMap()
971-
node_impl = get_node_impl(node)
972-
if node_impl is None:
973-
raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')
974-
ctx = FingerprintContext(len(ref_index) + len(new_ref_index))
975-
fp_matches = _check_graph_fingerprint(
976-
ctx, iter(fp), node, node_impl, ref_index, new_ref_index
977-
)
978-
return fp_matches
979-
980-
981-
def _check_graph_fingerprint(
982-
ctx: FingerprintContext,
983-
fp_iterator: tp.Iterator[tp.Hashable],
984-
node,
985-
node_impl: NodeImpl[Node, Leaf, AuxData],
986-
ref_index: RefMap,
987-
new_ref_index: RefMap,
988-
) -> bool:
989-
is_pytree_node_ = type(node_impl) is PytreeNodeImpl
990-
is_graph_node_ = type(node_impl) is GraphNodeImpl
991-
992-
if type(node) != next(fp_iterator):
993-
return False
994-
995-
if is_graph_node_:
996-
# append_fn(id(node))
997-
if id(node) != next(fp_iterator):
998-
return False
999-
if node in ref_index:
1000-
# append_fn(ref_index[node])
1001-
return ref_index[node] == next(fp_iterator)
1002-
elif node in new_ref_index:
1003-
# append_fn(new_ref_index[node])
1004-
return new_ref_index[node] == next(fp_iterator)
1005-
index = new_ref_index[node] = ctx.next_index
1006-
ctx.next_index += 1
1007-
else:
1008-
index = -1
1009-
1010-
values, metadata = node_impl.flatten(node)
1011-
1012-
# append_fn(index)
1013-
if index != next(fp_iterator):
1014-
return False
1015-
# append_fn(metadata)
1016-
if metadata != next(fp_iterator):
1017-
return False
1018-
1019-
for key, value in values:
1020-
value_node_impl = get_node_impl(value)
1021-
# append_fn(key)
1022-
if key != next(fp_iterator):
1023-
return False
1024-
if value_node_impl is not None:
1025-
if not _check_graph_fingerprint(
1026-
ctx,
1027-
fp_iterator,
1028-
value,
1029-
value_node_impl,
1030-
ref_index,
1031-
new_ref_index,
1032-
):
1033-
return False
1034-
elif isinstance(value, Variable):
1035-
# append_fn(id(value))
1036-
if id(value) != next(fp_iterator):
1037-
return False
1038-
# append_fn(type(value))
1039-
if type(value) != next(fp_iterator):
1040-
return False
1041-
if value in ref_index:
1042-
# append_fn(ref_index[value])
1043-
if ref_index[value] != next(fp_iterator):
1044-
return False
1045-
elif value in new_ref_index:
1046-
# append_fn(new_ref_index[value])
1047-
if new_ref_index[value] != next(fp_iterator):
1048-
return False
1049-
else:
1050-
variable_index = new_ref_index[value] = ctx.next_index
1051-
ctx.next_index += 1
1052-
# append_fn(variable_index)
1053-
if variable_index != next(fp_iterator):
1054-
return False
1055-
for key_value in value.get_metadata().items():
1056-
# append_fn(key_value)
1057-
if key_value != next(fp_iterator):
1058-
return False
1059-
else:
1060-
if isinstance(value, (jax.Array, np.ndarray)):
1061-
raise ValueError(f'Arrays leaves are not supported: {value}')
1062-
# append_fn(value)
1063-
if value != next(fp_iterator):
1064-
return False
1065-
1066-
return True
1067-
1068-
1069868
def _get_sorted_leaves(
1070869
xs: tp.Mapping[tp.Any, tp.Any],
1071870
) -> list[tp.Any]:
@@ -1629,13 +1428,17 @@ def create_static_cache(x):
16291428
node_cache = unflatten(
16301429
graphdef, flat_state, index_ref=index_ref, copy_variables=False
16311430
)
1632-
cached_new_ref_index = RefMap()
1633-
_fp = fingerprint(
1431+
start_index = len(cached_ref_index)
1432+
flatten(
16341433
node_cache,
16351434
ref_index=cached_ref_index,
1636-
new_ref_index=cached_new_ref_index,
1435+
with_paths=False,
1436+
)
1437+
cached_new_ref_index = RefMap(
1438+
(key, value)
1439+
for key, value in cached_ref_index.items()
1440+
if value >= start_index
16371441
)
1638-
cached_ref_index.update(cached_new_ref_index)
16391442
cache[node_cache] = StaticCache.create(
16401443
graphdef, paths, variables, cached_new_ref_index
16411444
)

tests/nnx/graph_utils_test.py

Lines changed: 21 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -845,44 +845,6 @@ def f(*pure_args):
845845
self.assertIs(m1, args_out[2]['b'])
846846
self.assertIs(m2, args_out[1])
847847

848-
def test_fingerprint_basic(self):
849-
m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
850-
fp1 = nnx.graph.fingerprint(m)
851-
fp2 = nnx.graph.fingerprint(m)
852-
853-
self.assertEqual(fp1, fp2)
854-
self.assertTrue(nnx.graph.check_fingerprint(m, fp1))
855-
self.assertTrue(nnx.graph.check_fingerprint(m, fp2))
856-
857-
def test_fingerprint_variable_id_sensitive(self):
858-
m1 = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
859-
fp1 = nnx.graph.fingerprint(m1)
860-
861-
m2 = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
862-
fp2 = nnx.graph.fingerprint(m2)
863-
864-
self.assertNotEqual(fp1, fp2)
865-
self.assertTrue(nnx.graph.check_fingerprint(m1, fp1))
866-
self.assertTrue(nnx.graph.check_fingerprint(m2, fp2))
867-
self.assertFalse(nnx.graph.check_fingerprint(m1, fp2))
868-
self.assertFalse(nnx.graph.check_fingerprint(m2, fp1))
869-
870-
def test_fingerprint_module_id_insensitive(self):
871-
m1 = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
872-
m2 = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
873-
874-
m1.kernel = m2.kernel
875-
m1.bias = m2.bias
876-
877-
fp1 = nnx.graph.fingerprint(m1)
878-
fp2 = nnx.graph.fingerprint(m2)
879-
880-
self.assertNotEqual(fp1, fp2)
881-
self.assertTrue(nnx.graph.check_fingerprint(m1, fp1))
882-
self.assertTrue(nnx.graph.check_fingerprint(m2, fp2))
883-
self.assertFalse(nnx.graph.check_fingerprint(m1, fp2))
884-
self.assertFalse(nnx.graph.check_fingerprint(m2, fp1))
885-
886848
def test_split_variable(self):
887849
v = nnx.Param(1)
888850
graphdef, state = nnx.split(v)
@@ -1137,6 +1099,27 @@ def test_iter_graph(self):
11371099
actual = [node for node in nodes if any(node is e for e in unique)]
11381100
self.assertEqual(list(map(index, actual)), list(map(index, expected)))
11391101

1102+
def test_cached_partial_docstring_example(self):
1103+
import optax
1104+
1105+
model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
1106+
optimizer = nnx.Optimizer(model, optax.adamw(1e-3), wrt=nnx.Param)
1107+
1108+
@nnx.jit
1109+
def train_step(model, optimizer, x, y):
1110+
def loss_fn(model):
1111+
return jnp.mean((model(x) - y) ** 2)
1112+
1113+
loss, grads = nnx.value_and_grad(loss_fn)(model)
1114+
optimizer.update(model, grads)
1115+
return loss
1116+
1117+
cached_train_step = nnx.cached_partial(train_step, model, optimizer)
1118+
1119+
for step in range(2):
1120+
x, y = jnp.ones((10, 2)), jnp.ones((10, 3))
1121+
loss = cached_train_step(x, y)
1122+
self.assertIsInstance(loss, jax.Array)
11401123

11411124
def test_find_duplicates(self):
11421125
class SharedModules(nnx.Module):

0 commit comments

Comments (0)