Skip to content

Commit ad5cf22

Browse files
committed
chore: add tree_unflatten method to tests
1 parent b9e32d7 commit ad5cf22

1 file changed

Lines changed: 10 additions & 4 deletions

File tree

tests/optimagic/parameters/test_tree_registry.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,15 +97,18 @@ def bounds_df():
9797
)
9898

9999

100-
def test_tree_methods_with_empty_namespace(bounds_df):
101-
leaves, _ = tree_flatten(bounds_df)
100+
def test_tree_methods_with_default_namespace(bounds_df):
101+
leaves, treedef = tree_flatten(bounds_df)
102102
assert len(leaves) == 1
103103
assert_frame_equal(leaves[0], bounds_df)
104104

105105
leaves = tree_leaves(bounds_df)
106106
assert len(leaves) == 1
107107
assert_frame_equal(leaves[0], bounds_df)
108108

109+
tree = tree_unflatten(treedef, leaves)
110+
assert_frame_equal(tree, bounds_df)
111+
109112
names = leaf_names(bounds_df)
110113
expected_names = [""]
111114
assert names == expected_names
@@ -115,15 +118,18 @@ def test_tree_methods_with_empty_namespace(bounds_df):
115118

116119

117120
@pytest.mark.parametrize("namespace", OPTREE_NAMESPACES)
118-
def test_tree_methods_with_optimagic_namespace(namespace, bounds_df):
121+
def test_tree_methods_with_registered_namespaces(namespace, bounds_df):
119122
expected_leaves = bounds_df[namespace].tolist()
120123

121-
leaves, _ = tree_flatten(bounds_df, namespace=namespace)
124+
leaves, treedef = tree_flatten(bounds_df, namespace=namespace)
122125
assert leaves == expected_leaves
123126

124127
leaves = tree_leaves(bounds_df, namespace=namespace)
125128
assert leaves == expected_leaves
126129

130+
tree = tree_unflatten(treedef, leaves, namespace=namespace)
131+
assert_frame_equal(tree, bounds_df)
132+
127133
names = leaf_names(bounds_df, namespace=namespace)
128134
assert names == ["alpha", "beta", "gamma"]
129135

0 commit comments

Comments
 (0)