@@ -97,15 +97,18 @@ def bounds_df():
9797 )
9898
9999
100- def test_tree_methods_with_empty_namespace (bounds_df ):
101- leaves , _ = tree_flatten (bounds_df )
100+ def test_tree_methods_with_default_namespace (bounds_df ):
101+ leaves , treedef = tree_flatten (bounds_df )
102102 assert len (leaves ) == 1
103103 assert_frame_equal (leaves [0 ], bounds_df )
104104
105105 leaves = tree_leaves (bounds_df )
106106 assert len (leaves ) == 1
107107 assert_frame_equal (leaves [0 ], bounds_df )
108108
109+ tree = tree_unflatten (treedef , leaves )
110+ assert_frame_equal (tree , bounds_df )
111+
109112 names = leaf_names (bounds_df )
110113 expected_names = ["" ]
111114 assert names == expected_names
@@ -115,15 +118,18 @@ def test_tree_methods_with_empty_namespace(bounds_df):
115118
116119
117120@pytest .mark .parametrize ("namespace" , OPTREE_NAMESPACES )
118- def test_tree_methods_with_optimagic_namespace (namespace , bounds_df ):
121+ def test_tree_methods_with_registered_namespaces (namespace , bounds_df ):
119122 expected_leaves = bounds_df [namespace ].tolist ()
120123
121- leaves , _ = tree_flatten (bounds_df , namespace = namespace )
124+ leaves , treedef = tree_flatten (bounds_df , namespace = namespace )
122125 assert leaves == expected_leaves
123126
124127 leaves = tree_leaves (bounds_df , namespace = namespace )
125128 assert leaves == expected_leaves
126129
130+ tree = tree_unflatten (treedef , leaves , namespace = namespace )
131+ assert_frame_equal (tree , bounds_df )
132+
127133 names = leaf_names (bounds_df , namespace = namespace )
128134 assert names == ["alpha" , "beta" , "gamma" ]
129135
0 commit comments