|
1 | | -"""Wrapper around pybaum get_registry to tailor it to optimagic.""" |
| 1 | +"""Wrapper around optree to tailor it to optimagic.""" |
2 | 2 |
|
3 | 3 | import itertools |
4 | 4 | from functools import partial |
@@ -88,32 +88,6 @@ def tree_equal(tree, other, is_leaf=None, namespace="", equality_checkers=None): |
88 | 88 | return equal |
89 | 89 |
|
90 | 90 |
|
91 | | -def _array_element_names(arr): |
92 | | - dim_names = [map(str, range(n)) for n in arr.shape] |
93 | | - names = list(map("_".join, itertools.product(*dim_names))) |
94 | | - return names |
95 | | - |
96 | | - |
97 | | -def _get_df_names(df): |
98 | | - index_strings = list(df.index.map(_index_element_to_string)) |
99 | | - if "value" in df: |
100 | | - out = index_strings |
101 | | - else: |
102 | | - out = ["_".join([loc, col]) for loc, col in product(index_strings, df.columns)] |
103 | | - |
104 | | - return out |
105 | | - |
106 | | - |
107 | | -def _index_element_to_string(element): |
108 | | - if isinstance(element, (tuple, list)): |
109 | | - as_strings = [str(entry) for entry in element] |
110 | | - res_string = "_".join(as_strings) |
111 | | - else: |
112 | | - res_string = str(element) |
113 | | - |
114 | | - return res_string |
115 | | - |
116 | | - |
117 | 91 | def _flatten_df(df, data_col): |
118 | 92 | is_value_df = "value" in df |
119 | 93 | if is_value_df: |
@@ -169,6 +143,32 @@ def _unflatten_jax_array(aux_data, leaves): |
169 | 143 | return jnp.array(leaves).reshape(aux_data) |
170 | 144 |
|
171 | 145 |
|
| 146 | +def _get_df_names(df): |
| 147 | + index_strings = list(df.index.map(_index_element_to_string)) |
| 148 | + if "value" in df: |
| 149 | + out = index_strings |
| 150 | + else: |
| 151 | + out = ["_".join([loc, col]) for loc, col in product(index_strings, df.columns)] |
| 152 | + |
| 153 | + return out |
| 154 | + |
| 155 | + |
| 156 | +def _index_element_to_string(element): |
| 157 | + if isinstance(element, (tuple, list)): |
| 158 | + as_strings = [str(entry) for entry in element] |
| 159 | + res_string = "_".join(as_strings) |
| 160 | + else: |
| 161 | + res_string = str(element) |
| 162 | + |
| 163 | + return res_string |
| 164 | + |
| 165 | + |
| 166 | +def _array_element_names(arr): |
| 167 | + dim_names = [map(str, range(n)) for n in arr.shape] |
| 168 | + names = list(map("_".join, itertools.product(*dim_names))) |
| 169 | + return names |
| 170 | + |
| 171 | + |
172 | 172 | for namespace in optree_namespaces: |
173 | 173 | optree.register_pytree_node( |
174 | 174 | pd.DataFrame, |
|
0 commit comments