@@ -865,207 +865,6 @@ def make_mutable_arraydef(value: variablelib.Ref):
865865 return
866866
867867
868- @dataclasses .dataclass (slots = True )
869- class FingerprintContext :
870- next_index : int
871-
872-
873- # TODO(cgarciae): the actual fingerprint object is not being used,
874- # only the traversal process is still relevant
875- def fingerprint (
876- node ,
877- / ,
878- * ,
879- ref_index : RefMap | None = None ,
880- new_ref_index : RefMap | None = None ,
881- ) -> list [tp .Hashable ]:
882- """ """
883- if ref_index is None :
884- ref_index = RefMap ()
885-
886- if new_ref_index is None :
887- new_ref_index = RefMap ()
888- node_impl = get_node_impl (node )
889- if node_impl is None :
890- raise RuntimeError (f'Unsupported type: { type (node )} , this is a bug.' )
891- ctx = FingerprintContext (len (ref_index ) + len (new_ref_index ))
892- fp : list [tp .Hashable ] = []
893- _graph_fingerprint (ctx , fp .append , node , node_impl , ref_index , new_ref_index )
894- return fp
895-
896-
897- def _graph_fingerprint (
898- ctx : FingerprintContext ,
899- append_fn : tp .Callable [[tp .Any ], None ],
900- node ,
901- node_impl : NodeImpl [Node , Leaf , AuxData ],
902- ref_index : RefMap ,
903- new_ref_index : RefMap ,
904- ):
905- is_pytree_node_ = type (node_impl ) is PytreeNodeImpl
906- is_graph_node_ = type (node_impl ) is GraphNodeImpl
907-
908- append_fn (type (node ))
909-
910- if is_graph_node_ :
911- append_fn (id (node ))
912- if node in ref_index :
913- append_fn (ref_index [node ])
914- return
915- elif node in new_ref_index :
916- append_fn (new_ref_index [node ])
917- return
918- index = new_ref_index [node ] = ctx .next_index
919- ctx .next_index += 1
920- else :
921- index = - 1
922-
923- values , metadata = node_impl .flatten (node )
924-
925- append_fn (index )
926- append_fn (metadata )
927-
928- for key , value in values :
929- value_node_impl = get_node_impl (value )
930- append_fn (key )
931- if value_node_impl is not None :
932- _graph_fingerprint (
933- ctx ,
934- append_fn ,
935- value ,
936- value_node_impl ,
937- ref_index ,
938- new_ref_index ,
939- )
940- elif isinstance (value , Variable ):
941- append_fn (id (value ))
942- append_fn (type (value ))
943- if value in ref_index :
944- append_fn (ref_index [value ])
945- elif value in new_ref_index :
946- append_fn (new_ref_index [value ])
947- else :
948- variable_index = new_ref_index [value ] = ctx .next_index
949- ctx .next_index += 1
950- append_fn (variable_index )
951- for key_value in value .get_metadata ().items ():
952- append_fn (key_value )
953- elif not isinstance (value , (jax .Array , np .ndarray )):
954- append_fn (value )
955-
956-
957- def check_fingerprint (
958- node ,
959- fp : list [tp .Hashable ],
960- / ,
961- * ,
962- ref_index : RefMap | None = None ,
963- new_ref_index : RefMap | None = None ,
964- ) -> bool :
965- """ """
966- if ref_index is None :
967- ref_index = RefMap ()
968-
969- if new_ref_index is None :
970- new_ref_index = RefMap ()
971- node_impl = get_node_impl (node )
972- if node_impl is None :
973- raise RuntimeError (f'Unsupported type: { type (node )} , this is a bug.' )
974- ctx = FingerprintContext (len (ref_index ) + len (new_ref_index ))
975- fp_matches = _check_graph_fingerprint (
976- ctx , iter (fp ), node , node_impl , ref_index , new_ref_index
977- )
978- return fp_matches
979-
980-
981- def _check_graph_fingerprint (
982- ctx : FingerprintContext ,
983- fp_iterator : tp .Iterator [tp .Hashable ],
984- node ,
985- node_impl : NodeImpl [Node , Leaf , AuxData ],
986- ref_index : RefMap ,
987- new_ref_index : RefMap ,
988- ) -> bool :
989- is_pytree_node_ = type (node_impl ) is PytreeNodeImpl
990- is_graph_node_ = type (node_impl ) is GraphNodeImpl
991-
992- if type (node ) != next (fp_iterator ):
993- return False
994-
995- if is_graph_node_ :
996- # append_fn(id(node))
997- if id (node ) != next (fp_iterator ):
998- return False
999- if node in ref_index :
1000- # append_fn(ref_index[node])
1001- return ref_index [node ] == next (fp_iterator )
1002- elif node in new_ref_index :
1003- # append_fn(new_ref_index[node])
1004- return new_ref_index [node ] == next (fp_iterator )
1005- index = new_ref_index [node ] = ctx .next_index
1006- ctx .next_index += 1
1007- else :
1008- index = - 1
1009-
1010- values , metadata = node_impl .flatten (node )
1011-
1012- # append_fn(index)
1013- if index != next (fp_iterator ):
1014- return False
1015- # append_fn(metadata)
1016- if metadata != next (fp_iterator ):
1017- return False
1018-
1019- for key , value in values :
1020- value_node_impl = get_node_impl (value )
1021- # append_fn(key)
1022- if key != next (fp_iterator ):
1023- return False
1024- if value_node_impl is not None :
1025- if not _check_graph_fingerprint (
1026- ctx ,
1027- fp_iterator ,
1028- value ,
1029- value_node_impl ,
1030- ref_index ,
1031- new_ref_index ,
1032- ):
1033- return False
1034- elif isinstance (value , Variable ):
1035- # append_fn(id(value))
1036- if id (value ) != next (fp_iterator ):
1037- return False
1038- # append_fn(type(value))
1039- if type (value ) != next (fp_iterator ):
1040- return False
1041- if value in ref_index :
1042- # append_fn(ref_index[value])
1043- if ref_index [value ] != next (fp_iterator ):
1044- return False
1045- elif value in new_ref_index :
1046- # append_fn(new_ref_index[value])
1047- if new_ref_index [value ] != next (fp_iterator ):
1048- return False
1049- else :
1050- variable_index = new_ref_index [value ] = ctx .next_index
1051- ctx .next_index += 1
1052- # append_fn(variable_index)
1053- if variable_index != next (fp_iterator ):
1054- return False
1055- for key_value in value .get_metadata ().items ():
1056- # append_fn(key_value)
1057- if key_value != next (fp_iterator ):
1058- return False
1059- else :
1060- if isinstance (value , (jax .Array , np .ndarray )):
1061- raise ValueError (f'Arrays leaves are not supported: { value } ' )
1062- # append_fn(value)
1063- if value != next (fp_iterator ):
1064- return False
1065-
1066- return True
1067-
1068-
1069868def _get_sorted_leaves (
1070869 xs : tp .Mapping [tp .Any , tp .Any ],
1071870) -> list [tp .Any ]:
@@ -1629,13 +1428,17 @@ def create_static_cache(x):
16291428 node_cache = unflatten (
16301429 graphdef , flat_state , index_ref = index_ref , copy_variables = False
16311430 )
1632- cached_new_ref_index = RefMap ( )
1633- _fp = fingerprint (
1431+ start_index = len ( cached_ref_index )
1432+ flatten (
16341433 node_cache ,
16351434 ref_index = cached_ref_index ,
1636- new_ref_index = cached_new_ref_index ,
1435+ with_paths = False ,
1436+ )
1437+ cached_new_ref_index = RefMap (
1438+ (key , value )
1439+ for key , value in cached_ref_index .items ()
1440+ if value >= start_index
16371441 )
1638- cached_ref_index .update (cached_new_ref_index )
16391442 cache [node_cache ] = StaticCache .create (
16401443 graphdef , paths , variables , cached_new_ref_index
16411444 )
0 commit comments