Skip to content

Commit 9d8c744

Browse files
committed
chore: use namespaces for passing data_col value for dataframes
1 parent 5aa2b08 commit 9d8c744

4 files changed

Lines changed: 101 additions & 116 deletions

File tree

src/optimagic/parameters/bounds.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from optimagic.parameters.tree_registry import (
1212
extended,
1313
leaf_names,
14-
set_data_col_df_attribute,
1514
tree_map,
1615
)
1716
from optimagic.parameters.tree_registry import (
@@ -181,9 +180,7 @@ def _update_bounds_and_flatten(
181180
np.ndarray: The updated and flattened bounds.
182181
183182
"""
184-
flat_nan_tree = tree_leaves(
185-
set_data_col_df_attribute(nan_tree, data_col=kind), namespace=extended
186-
)
183+
flat_nan_tree = tree_leaves(nan_tree, namespace=kind)
187184
if bounds is not None:
188185
flat_bounds = tree_leaves(bounds, namespace=extended)
189186

Lines changed: 100 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,34 @@
11
"""Wrapper around pybaum get_registry to tailor it to optimagic."""
22

33
import itertools
4+
from functools import partial
45
from itertools import product
56

67
import numpy as np
78
import optree
89
import pandas as pd
910
from optree.pytree import PyTreeSpec
1011

11-
from optimagic.typing import extended_namespace
12+
extended = "value"
13+
namespaces = [
14+
extended,
15+
"lower_bound",
16+
"upper_bound",
17+
"soft_lower_bound",
18+
"soft_upper_bound",
19+
]
1220

13-
14-
def _get_df_names(df):
15-
index_strings = list(df.index.map(_index_element_to_string))
16-
if "value" in df:
17-
out = index_strings
18-
else:
19-
out = ["_".join([loc, col]) for loc, col in product(index_strings, df.columns)]
20-
21-
return out
22-
23-
24-
def _index_element_to_string(element):
25-
if isinstance(element, (tuple, list)):
26-
as_strings = [str(entry) for entry in element]
27-
res_string = "_".join(as_strings)
28-
else:
29-
res_string = str(element)
30-
31-
return res_string
21+
EQUALITY_CHECKERS = {}
22+
EQUALITY_CHECKERS[np.ndarray.__name__] = lambda a, b: bool((a == b).all())
23+
EQUALITY_CHECKERS[pd.Series.__name__] = lambda a, b: a.equals(b)
24+
EQUALITY_CHECKERS[pd.DataFrame.__name__] = lambda a, b: a.equals(b)
3225

3326

3427
def tree_flatten(tree, is_leaf=None, namespace=""):
35-
with optree.dict_insertion_ordered(True, namespace=extended_namespace):
28+
if namespace:
29+
with optree.dict_insertion_ordered(True, namespace=namespace):
30+
return optree.tree_flatten(tree, is_leaf=is_leaf, namespace=namespace)
31+
else:
3632
return optree.tree_flatten(tree, is_leaf=is_leaf, namespace=namespace)
3733

3834

@@ -41,9 +37,6 @@ def tree_just_flatten(tree, is_leaf=None, namespace=""):
4137
return leaves
4238

4339

44-
extended = extended_namespace
45-
46-
4740
def tree_unflatten(treedef, leaves, is_leaf=None, namespace=""):
4841
if not isinstance(treedef, PyTreeSpec):
4942
_, treedef = tree_flatten(treedef, is_leaf=is_leaf, namespace=namespace)
@@ -60,14 +53,35 @@ def leaf_names(tree, is_leaf=None, namespace="", separator="_"):
6053
return [separator.join(str(p) for p in path) for path in paths]
6154

6255

63-
def set_data_col_df_attribute(tree, data_col):
64-
def set_attr(node):
65-
if isinstance(node, pd.DataFrame):
66-
node = node.copy()
67-
node.attrs["data_col"] = data_col
68-
return node
56+
def tree_equal(tree, other, is_leaf=None, namespace="", equality_checkers=None):
57+
equality_checkers = (
58+
EQUALITY_CHECKERS
59+
if equality_checkers is None
60+
else {**EQUALITY_CHECKERS, **equality_checkers}
61+
)
62+
63+
first_flat, first_treespec = tree_flatten(
64+
tree, is_leaf=is_leaf, namespace=namespace
65+
)
66+
second_flat, second_treespec = tree_flatten(
67+
other, is_leaf=is_leaf, namespace=namespace
68+
)
6969

70-
return tree_map(set_attr, tree)
70+
first_names = leaf_names(tree, is_leaf=is_leaf, namespace=namespace)
71+
second_names = leaf_names(other, is_leaf=is_leaf, namespace=namespace)
72+
73+
equal = first_names == second_names and first_treespec == second_treespec
74+
75+
if equal:
76+
for first, second in zip(first_flat, second_flat, strict=True):
77+
check_func = equality_checkers.get(
78+
type(first).__name__, lambda a, b: a == b
79+
)
80+
equal = equal and check_func(first, second)
81+
if not equal:
82+
break
83+
84+
return equal
7185

7286

7387
def _array_element_names(arr):
@@ -76,8 +90,27 @@ def _array_element_names(arr):
7690
return names
7791

7892

79-
def _flatten_df_optree(df):
80-
data_col = df.attrs.get("data_col", "value")
93+
def _get_df_names(df):
94+
index_strings = list(df.index.map(_index_element_to_string))
95+
if "value" in df:
96+
out = index_strings
97+
else:
98+
out = ["_".join([loc, col]) for loc, col in product(index_strings, df.columns)]
99+
100+
return out
101+
102+
103+
def _index_element_to_string(element):
104+
if isinstance(element, (tuple, list)):
105+
as_strings = [str(entry) for entry in element]
106+
res_string = "_".join(as_strings)
107+
else:
108+
res_string = str(element)
109+
110+
return res_string
111+
112+
113+
def _flatten_df(df, data_col):
81114
is_value_df = "value" in df
82115
if is_value_df:
83116
flat = df.get(data_col, default=np.full(len(df), np.nan)).tolist()
@@ -91,8 +124,7 @@ def _flatten_df_optree(df):
91124
return flat, aux_data, _get_df_names(df)
92125

93126

94-
def _unflatten_df_optree(aux_data, leaves):
95-
data_col = aux_data["df"].attrs.get("data_col", "value")
127+
def _unflatten_df(aux_data, leaves, data_col):
96128
if aux_data["is_value_df"]:
97129
out = aux_data["df"].assign(**{data_col: leaves})
98130
else:
@@ -104,61 +136,44 @@ def _unflatten_df_optree(aux_data, leaves):
104136
return out
105137

106138

107-
optree.register_pytree_node(
108-
pd.DataFrame,
109-
_flatten_df_optree,
110-
_unflatten_df_optree,
111-
namespace=extended_namespace,
112-
)
113-
114-
optree.register_pytree_node(
115-
pd.Series,
116-
lambda sr: (
117-
sr.tolist(),
118-
{"index": sr.index, "name": sr.name},
119-
list(sr.index.map(_index_element_to_string)),
120-
),
121-
lambda aux_data, leaves: pd.Series(leaves, **aux_data),
122-
namespace=extended_namespace,
123-
)
124-
125-
optree.register_pytree_node(
126-
np.ndarray,
127-
lambda arr: (arr.flatten().tolist(), arr.shape, _array_element_names(arr)),
128-
lambda aux_data, leaves: np.array(leaves).reshape(aux_data),
129-
namespace=extended_namespace,
130-
)
139+
def _flatten_series(series: pd.Series):
140+
return (
141+
series.tolist(),
142+
{"index": series.index, "name": series.name},
143+
list(series.index.map(_index_element_to_string)),
144+
)
131145

132-
EQUALITY_CHECKERS = {}
133-
EQUALITY_CHECKERS[np.ndarray] = lambda a, b: bool((a == b).all())
134-
EQUALITY_CHECKERS[pd.Series] = lambda a, b: a.equals(b)
135-
EQUALITY_CHECKERS[pd.DataFrame] = lambda a, b: a.equals(b)
136146

147+
def _unflatten_series(aux_data, leaves):
148+
return pd.Series(leaves, **aux_data)
137149

138-
def tree_equal(tree, other, is_leaf=None, namespace="", equality_checkers=None):
139-
equality_checkers = (
140-
EQUALITY_CHECKERS
141-
if equality_checkers is None
142-
else {**EQUALITY_CHECKERS, **equality_checkers}
143-
)
144150

145-
first_flat, first_treespec = tree_flatten(
146-
tree, is_leaf=is_leaf, namespace=namespace
147-
)
148-
second_flat, second_treespec = tree_flatten(
149-
other, is_leaf=is_leaf, namespace=namespace
150-
)
151+
def _flatten_ndarray(arr: np.ndarray):
152+
return arr.flatten().tolist(), arr.shape, _array_element_names(arr)
151153

152-
first_names = leaf_names(tree, is_leaf=is_leaf, namespace=namespace)
153-
second_names = leaf_names(tree, is_leaf=is_leaf, namespace=namespace)
154154

155-
equal = first_names == second_names and first_treespec == second_treespec
155+
def _unflatten_ndarray(aux_data, leaves):
156+
return np.array(leaves).reshape(aux_data)
156157

157-
if equal:
158-
for first, second in zip(first_flat, second_flat, strict=True):
159-
check_func = equality_checkers.get(type(first), lambda a, b: a == b)
160-
equal = equal and check_func(first, second)
161-
if not equal:
162-
break
163158

164-
return equal
159+
for namespace in namespaces:
160+
optree.register_pytree_node(
161+
pd.DataFrame,
162+
partial(_flatten_df, data_col=namespace),
163+
partial(_unflatten_df, data_col=namespace),
164+
namespace=namespace,
165+
)
166+
167+
optree.register_pytree_node(
168+
pd.Series,
169+
_flatten_series,
170+
_unflatten_series,
171+
namespace=namespace,
172+
)
173+
174+
optree.register_pytree_node(
175+
np.ndarray,
176+
_flatten_ndarray,
177+
_unflatten_ndarray,
178+
namespace=namespace,
179+
)

src/optimagic/typing.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
Scalar = Any
2323

2424
T = TypeVar("T")
25-
extended_namespace = "extended_namespace"
2625

2726

2827
class AggregationLevel(Enum):

tests/optimagic/parameters/test_tree_registry.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from optimagic.parameters.tree_registry import (
77
extended,
88
leaf_names,
9-
set_data_col_df_attribute,
109
tree_flatten,
1110
tree_unflatten,
1211
)
@@ -61,28 +60,3 @@ def test_unflatten_partially_numeric_df(other_df):
6160
def test_leaf_names_partially_numeric_df(other_df):
6261
names = leaf_names(other_df, namespace=extended)
6362
assert names == ["alpha_b", "alpha_c", "beta_b", "beta_c", "gamma_b", "gamma_c"]
64-
65-
66-
def test_set_data_col_attribute_assigns_attribute(value_df):
67-
df = set_data_col_df_attribute(value_df, data_col="attr")
68-
assert df.attrs.get("data_col") == "attr"
69-
assert value_df.attrs.get("data_col") is None
70-
71-
72-
def test_set_data_col_attribute_unflattened_tree_has_attribute(value_df):
73-
df = set_data_col_df_attribute(value_df, data_col="attr")
74-
tree, treedef = tree_flatten(df, namespace=extended)
75-
df = tree_unflatten(treedef, tree)
76-
assert df.attrs.get("data_col") == "attr"
77-
78-
79-
def test_set_data_col_attribute_returns_nan(value_df):
80-
df = set_data_col_df_attribute(value_df, data_col="attr")
81-
tree, treedef = tree_flatten(df, namespace=extended)
82-
assert all(np.isnan(value) for value in tree)
83-
84-
85-
def test_set_data_col_attribute_returs_column_values(value_df):
86-
df = set_data_col_df_attribute(value_df, data_col="a")
87-
tree, treedef = tree_flatten(df, namespace=extended)
88-
assert tree == [0, 2, 4]

0 commit comments

Comments
 (0)