|
9 | 9 | import pandas as pd |
10 | 10 | from optree.pytree import PyTreeSpec |
11 | 11 |
|
12 | | -from optimagic.typing import OPTREE_NAMESPACES |
| 12 | +from optimagic.typing import DEFAULT_NAMESPACE, OPTREE_NAMESPACES |
13 | 13 |
|
14 | 14 | try: |
15 | 15 | import jax.numpy as jnp # type: ignore[import-not-found] |
|
21 | 21 | _are_namespaces_registered = False |
22 | 22 |
|
23 | 23 |
|
24 | | -def tree_flatten(tree, is_leaf=None, namespace=""): |
| 24 | +def tree_flatten(tree, is_leaf=None, namespace=DEFAULT_NAMESPACE): |
25 | 25 | """Flatten a pytree.""" |
26 | 26 | _register_namespaces() |
27 | 27 | _check_namespace(namespace) |
| 28 | + with optree.dict_insertion_ordered(True, namespace=namespace): |
| 29 | + return optree.tree_flatten(tree, is_leaf=is_leaf, namespace=namespace) |
28 | 30 |
|
29 | | - return _with_insertion_order(namespace, optree.tree_flatten, tree, is_leaf=is_leaf) |
30 | 31 |
|
31 | | - |
32 | | -def tree_just_flatten(tree, is_leaf=None, namespace=""): |
| 32 | +def tree_just_flatten(tree, is_leaf=None, namespace=DEFAULT_NAMESPACE): |
33 | 33 | """Get the leaves of a pytree.""" |
34 | 34 | _register_namespaces() |
35 | 35 | _check_namespace(namespace) |
36 | | - |
37 | | - return _with_insertion_order(namespace, optree.tree_leaves, tree, is_leaf=is_leaf) |
| 36 | + with optree.dict_insertion_ordered(True, namespace=namespace): |
| 37 | + return optree.tree_leaves(tree, is_leaf=is_leaf, namespace=namespace) |
38 | 38 |
|
39 | 39 |
|
40 | | -def tree_unflatten(treedef, leaves, namespace=""): |
| 40 | +def tree_unflatten(treedef, leaves, namespace=DEFAULT_NAMESPACE): |
41 | 41 | """Reconstruct a pytree from the tree definition and the leaves.""" |
42 | 42 | _register_namespaces() |
43 | 43 |
|
44 | 44 | if not isinstance(treedef, PyTreeSpec): |
45 | 45 | _check_namespace(namespace) |
46 | | - treedef = _with_insertion_order(namespace, optree.tree_structure, treedef) |
| 46 | + with optree.dict_insertion_ordered(True, namespace=namespace): |
| 47 | + treedef = optree.tree_structure(treedef, namespace=namespace) |
47 | 48 |
|
48 | | - # optree.tree_unflatten doesn't need to be wrapped with _with_insertion_order |
49 | | - # because it keeps the insertion order for dictionaries. |
| 49 | + # Doesn't need to be wrapped with dict_insertion_ordered |
| 50 | + # because it keeps the insertion order for dictionaries by default. |
50 | 51 | return optree.tree_unflatten(treedef, leaves) |
51 | 52 |
|
52 | 53 |
|
53 | | -def tree_map(func, tree, is_leaf=None, namespace=""): |
54 | | - """Map an input function over pytree args to produce a new pytree. |
55 | | -
|
56 | | - optree.tree_map always respects insertion order for dictionaries and doesn't |
57 | | - require to be wrapped with _with_insertion_order. |
58 | | - """ |
| 54 | +def tree_map(func, tree, is_leaf=None, namespace=DEFAULT_NAMESPACE): |
| 55 | + """Map an input function over pytree args to produce a new pytree.""" |
59 | 56 | _register_namespaces() |
60 | 57 | _check_namespace(namespace) |
61 | 58 |
|
| 59 | + # Doesn't need to be wrapped with dict_insertion_ordered |
| 60 | + # because it keeps the insertion order for dictionaries by default. |
62 | 61 | return optree.tree_map(func, tree, is_leaf=is_leaf, namespace=namespace) |
63 | 62 |
|
64 | 63 |
|
65 | | -def leaf_names(tree, is_leaf=None, namespace="", separator="_"): |
| 64 | +def leaf_names(tree, is_leaf=None, namespace=DEFAULT_NAMESPACE, separator="_"): |
66 | 65 | """Get the path names for tree leaves.""" |
67 | 66 | _register_namespaces() |
68 | 67 | _check_namespace(namespace) |
69 | 68 |
|
70 | | - paths, _, _ = _with_insertion_order( |
71 | | - namespace, optree.tree_flatten_with_path, tree, is_leaf=is_leaf |
72 | | - ) |
| 69 | + with optree.dict_insertion_ordered(True, namespace=namespace): |
| 70 | + paths, _, _ = optree.tree_flatten_with_path( |
| 71 | + tree, is_leaf=is_leaf, namespace=namespace |
| 72 | + ) |
73 | 73 | return [separator.join(str(p) for p in path) for path in paths] |
74 | 74 |
|
75 | 75 |
|
76 | | -def _with_insertion_order(namespace, optree_func, *args, **kwargs): |
77 | | - """Call an optree function, preserving dict key order within a namespace. |
78 | | -
|
79 | | - By default, optree sorts dictionary keys. When a namespace is provided, |
80 | | - this wrapper enables dict_insertion_ordered mode so that the original |
81 | | - key order is preserved in the output. |
82 | | - """ |
83 | | - if namespace: |
84 | | - with optree.dict_insertion_ordered(True, namespace=namespace): |
85 | | - return optree_func(*args, namespace=namespace, **kwargs) |
86 | | - return optree_func(*args, namespace=namespace, **kwargs) |
87 | | - |
88 | | - |
89 | | -def tree_equal(tree, other, is_leaf=None, namespace="", equality_checkers=None): |
| 76 | +def tree_equal( |
| 77 | + tree, other, is_leaf=None, namespace=DEFAULT_NAMESPACE, equality_checkers=None |
| 78 | +): |
90 | 79 | """Check the equality between two trees.""" |
91 | 80 | equality_checkers = ( |
92 | 81 | _get_equality_checkers() |
@@ -136,8 +125,8 @@ def _get_equality_checkers(): |
136 | 125 |
|
137 | 126 |
|
138 | 127 | def _check_namespace(namespace: str) -> None: |
139 | | - """Checks if the namespace is a registered and raise a warning.""" |
140 | | - if namespace and namespace not in OPTREE_NAMESPACES: |
| 128 | + """Checks if the namespace is registered and raise a warning.""" |
| 129 | + if namespace != DEFAULT_NAMESPACE and namespace not in OPTREE_NAMESPACES: |
141 | 130 | warnings.warn( |
142 | 131 | f"Namespace '{namespace}' is not registered. " |
143 | 132 | f"Registered namespaces are: {','.join(OPTREE_NAMESPACES)}. " |
|
0 commit comments