Skip to content

Commit 8276c6d

Browse files
committed
chore: rearrange method order
1 parent b932c30 commit 8276c6d

1 file changed

Lines changed: 27 additions & 27 deletions

File tree

src/optimagic/parameters/tree_registry.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Wrapper around pybaum get_registry to tailor it to optimagic."""
1+
"""Wrapper around optree to tailor it to optimagic."""
22

33
import itertools
44
from functools import partial
@@ -88,32 +88,6 @@ def tree_equal(tree, other, is_leaf=None, namespace="", equality_checkers=None):
8888
return equal
8989

9090

91-
def _array_element_names(arr):
92-
dim_names = [map(str, range(n)) for n in arr.shape]
93-
names = list(map("_".join, itertools.product(*dim_names)))
94-
return names
95-
96-
97-
def _get_df_names(df):
98-
index_strings = list(df.index.map(_index_element_to_string))
99-
if "value" in df:
100-
out = index_strings
101-
else:
102-
out = ["_".join([loc, col]) for loc, col in product(index_strings, df.columns)]
103-
104-
return out
105-
106-
107-
def _index_element_to_string(element):
108-
if isinstance(element, (tuple, list)):
109-
as_strings = [str(entry) for entry in element]
110-
res_string = "_".join(as_strings)
111-
else:
112-
res_string = str(element)
113-
114-
return res_string
115-
116-
11791
def _flatten_df(df, data_col):
11892
is_value_df = "value" in df
11993
if is_value_df:
@@ -169,6 +143,32 @@ def _unflatten_jax_array(aux_data, leaves):
169143
return jnp.array(leaves).reshape(aux_data)
170144

171145

146+
def _get_df_names(df):
147+
index_strings = list(df.index.map(_index_element_to_string))
148+
if "value" in df:
149+
out = index_strings
150+
else:
151+
out = ["_".join([loc, col]) for loc, col in product(index_strings, df.columns)]
152+
153+
return out
154+
155+
156+
def _index_element_to_string(element):
157+
if isinstance(element, (tuple, list)):
158+
as_strings = [str(entry) for entry in element]
159+
res_string = "_".join(as_strings)
160+
else:
161+
res_string = str(element)
162+
163+
return res_string
164+
165+
166+
def _array_element_names(arr):
167+
dim_names = [map(str, range(n)) for n in arr.shape]
168+
names = list(map("_".join, itertools.product(*dim_names)))
169+
return names
170+
171+
172172
for namespace in optree_namespaces:
173173
optree.register_pytree_node(
174174
pd.DataFrame,

0 commit comments

Comments
 (0)