@@ -34,24 +34,26 @@ def tree_just_flatten(tree, is_leaf=None, namespace=""):
3434 return _with_insertion_order (namespace , optree .tree_leaves , tree , is_leaf = is_leaf )
3535
3636
37- def tree_unflatten (treedef , leaves , is_leaf = None , namespace = "" ):
37+ def tree_unflatten (treedef , leaves , namespace = "" ):
3838 """Reconstruct a pytree from the tree definition and the leaves."""
3939 _register_namespaces ()
4040
4141 if not isinstance (treedef , PyTreeSpec ):
42- treedef = _with_insertion_order (
43- namespace , optree .tree_structure , treedef , is_leaf = is_leaf
44- )
42+ treedef = _with_insertion_order (namespace , optree .tree_structure , treedef )
4543
44+ # optree.tree_unflatten doesn't need to be wrapped with _with_insertion_order
45+ # because it keeps the insertion order for dictionaries.
4646 return optree .tree_unflatten (treedef , leaves )
4747
4848
4949def tree_map (func , tree , is_leaf = None , namespace = "" ):
50- """Map an input function over pytree args to produce a new pytree."""
50+ """Map an input function over pytree args to produce a new pytree.
51+
52+ optree.tree_map always respects insertion order for dictionaries and doesn't
53+ require to be wrapped with _with_insertion_order.
54+ """
5155 _register_namespaces ()
52- return _with_insertion_order (
53- namespace , optree .tree_map , func , tree , is_leaf = is_leaf
54- )
56+ return optree .tree_map (func , tree , is_leaf = is_leaf , namespace = namespace )
5557
5658
5759def leaf_names (tree , is_leaf = None , namespace = "" , separator = "_" ):
0 commit comments