11"""Wrapper around pybaum get_registry to tailor it to optimagic."""
22
33import itertools
4+ from functools import partial
45from itertools import product
56
67import numpy as np
78import optree
89import pandas as pd
910from optree .pytree import PyTreeSpec
1011
11- from optimagic .typing import extended_namespace
# Default namespace name.
extended = "value"

# All namespace names known to this module. Further down in this file each
# entry is used both as an optree registry namespace and as the ``data_col``
# bound to the DataFrame flatten/unflatten functions.
namespaces = [extended]
namespaces += ["lower_bound", "upper_bound"]  # hard bounds
namespaces += ["soft_lower_bound", "soft_upper_bound"]  # soft bounds
1220
13-
14- def _get_df_names (df ):
15- index_strings = list (df .index .map (_index_element_to_string ))
16- if "value" in df :
17- out = index_strings
18- else :
19- out = ["_" .join ([loc , col ]) for loc , col in product (index_strings , df .columns )]
20-
21- return out
22-
23-
24- def _index_element_to_string (element ):
25- if isinstance (element , (tuple , list )):
26- as_strings = [str (entry ) for entry in element ]
27- res_string = "_" .join (as_strings )
28- else :
29- res_string = str (element )
30-
31- return res_string
# Leaf comparison functions, keyed by the *name* of the leaf's type
# (``type(leaf).__name__``). Each callable takes two leaves and returns a bool.
EQUALITY_CHECKERS = {
    # ``np.array_equal`` returns False for arrays of different shapes, whereas
    # ``(a == b).all()`` either raises on incompatible shapes or silently
    # broadcasts (e.g. shape (2,) vs. (1,)) and may report equality.
    np.ndarray.__name__: lambda a, b: bool(np.array_equal(a, b)),
    pd.Series.__name__: lambda a, b: a.equals(b),
    pd.DataFrame.__name__: lambda a, b: a.equals(b),
}
3225
3326
def tree_flatten(tree, is_leaf=None, namespace=""):
    """Flatten ``tree`` with ``optree.tree_flatten``.

    Args:
        tree: The pytree to flatten.
        is_leaf: Forwarded unchanged to ``optree.tree_flatten``.
        namespace: Forwarded unchanged to ``optree.tree_flatten``. If it is
            non-empty, the flattening additionally runs inside
            ``optree.dict_insertion_ordered(True, namespace=namespace)``.

    Returns:
        Whatever ``optree.tree_flatten`` returns; callers in this module
        unpack it as ``(leaves, treespec)``.

    """
    # Without a namespace, skip the dict-ordering context entirely.
    if not namespace:
        return optree.tree_flatten(tree, is_leaf=is_leaf, namespace=namespace)

    with optree.dict_insertion_ordered(True, namespace=namespace):
        return optree.tree_flatten(tree, is_leaf=is_leaf, namespace=namespace)
3733
3834
@@ -41,9 +37,6 @@ def tree_just_flatten(tree, is_leaf=None, namespace=""):
4137 return leaves
4238
4339
44- extended = extended_namespace
45-
46-
4740def tree_unflatten (treedef , leaves , is_leaf = None , namespace = "" ):
4841 if not isinstance (treedef , PyTreeSpec ):
4942 _ , treedef = tree_flatten (treedef , is_leaf = is_leaf , namespace = namespace )
@@ -60,14 +53,35 @@ def leaf_names(tree, is_leaf=None, namespace="", separator="_"):
6053 return [separator .join (str (p ) for p in path ) for path in paths ]
6154
6255
63- def set_data_col_df_attribute (tree , data_col ):
64- def set_attr (node ):
65- if isinstance (node , pd .DataFrame ):
66- node = node .copy ()
67- node .attrs ["data_col" ] = data_col
68- return node
def tree_equal(tree, other, is_leaf=None, namespace="", equality_checkers=None):
    """Check whether two pytrees have the same structure and equal leaves.

    Both trees are flattened with ``tree_flatten`` and their leaf names are
    computed with ``leaf_names``. The trees are equal if the leaf names match,
    the tree specs match and every pair of corresponding leaves compares equal.

    Args:
        tree: First pytree.
        other: Second pytree.
        is_leaf: Forwarded to ``tree_flatten`` and ``leaf_names``.
        namespace: Forwarded to ``tree_flatten`` and ``leaf_names``.
        equality_checkers: Optional dict mapping a type name
            (``type(leaf).__name__``) to a callable ``(a, b) -> bool``. Its
            entries take precedence over the module-level ``EQUALITY_CHECKERS``.

    Returns:
        bool: True if the trees are equal, False otherwise. The result is always
        a plain ``bool``, even if a leaf comparison yields e.g. ``numpy.bool_``.

    """
    checkers = dict(EQUALITY_CHECKERS)
    if equality_checkers is not None:
        checkers.update(equality_checkers)

    first_flat, first_treespec = tree_flatten(
        tree, is_leaf=is_leaf, namespace=namespace
    )
    second_flat, second_treespec = tree_flatten(
        other, is_leaf=is_leaf, namespace=namespace
    )

    first_names = leaf_names(tree, is_leaf=is_leaf, namespace=namespace)
    second_names = leaf_names(other, is_leaf=is_leaf, namespace=namespace)

    same_structure = first_names == second_names and first_treespec == second_treespec
    if not same_structure:
        return False

    def _default_check(a, b):
        return a == b

    for first, second in zip(first_flat, second_flat, strict=True):
        # The checker is selected by the type name of the leaf in the first tree.
        check_func = checkers.get(type(first).__name__, _default_check)
        if not check_func(first, second):
            return False

    return True
7185
7286
7387def _array_element_names (arr ):
@@ -76,8 +90,27 @@ def _array_element_names(arr):
7690 return names
7791
7892
79- def _flatten_df_optree (df ):
80- data_col = df .attrs .get ("data_col" , "value" )
def _get_df_names(df):
    """Return string names for the flattened entries of a DataFrame.

    If ``df`` has a "value" column, there is one name per row, built from the
    index label. Otherwise there is one name per (row, column) pair, iterating
    rows in the outer loop and columns in the inner loop, formatted as
    ``"<row>_<column>"``.

    Args:
        df (pd.DataFrame): The DataFrame whose entries are named.

    Returns:
        list[str]: The names.

    """
    index_strings = list(df.index.map(_index_element_to_string))
    if "value" in df:
        return index_strings

    # Column labels are converted exactly like index labels. Joining the raw
    # labels would raise a TypeError for non-string columns (e.g. the default
    # integer columns or tuples from a MultiIndex).
    column_strings = [_index_element_to_string(col) for col in df.columns]
    return ["_".join(pair) for pair in product(index_strings, column_strings)]
101+
102+
103+ def _index_element_to_string (element ):
104+ if isinstance (element , (tuple , list )):
105+ as_strings = [str (entry ) for entry in element ]
106+ res_string = "_" .join (as_strings )
107+ else :
108+ res_string = str (element )
109+
110+ return res_string
111+
112+
113+ def _flatten_df (df , data_col ):
81114 is_value_df = "value" in df
82115 if is_value_df :
83116 flat = df .get (data_col , default = np .full (len (df ), np .nan )).tolist ()
@@ -91,8 +124,7 @@ def _flatten_df_optree(df):
91124 return flat , aux_data , _get_df_names (df )
92125
93126
94- def _unflatten_df_optree (aux_data , leaves ):
95- data_col = aux_data ["df" ].attrs .get ("data_col" , "value" )
127+ def _unflatten_df (aux_data , leaves , data_col ):
96128 if aux_data ["is_value_df" ]:
97129 out = aux_data ["df" ].assign (** {data_col : leaves })
98130 else :
@@ -104,61 +136,44 @@ def _unflatten_df_optree(aux_data, leaves):
104136 return out
105137
106138
107- optree .register_pytree_node (
108- pd .DataFrame ,
109- _flatten_df_optree ,
110- _unflatten_df_optree ,
111- namespace = extended_namespace ,
112- )
113-
114- optree .register_pytree_node (
115- pd .Series ,
116- lambda sr : (
117- sr .tolist (),
118- {"index" : sr .index , "name" : sr .name },
119- list (sr .index .map (_index_element_to_string )),
120- ),
121- lambda aux_data , leaves : pd .Series (leaves , ** aux_data ),
122- namespace = extended_namespace ,
123- )
124-
125- optree .register_pytree_node (
126- np .ndarray ,
127- lambda arr : (arr .flatten ().tolist (), arr .shape , _array_element_names (arr )),
128- lambda aux_data , leaves : np .array (leaves ).reshape (aux_data ),
129- namespace = extended_namespace ,
130- )
def _flatten_series(series: pd.Series):
    """Flatten a Series into ``(leaves, aux_data, names)``.

    Returns:
        tuple: The values as a list, a dict holding the index and name of the
        Series, and one string name per entry derived from the index labels.

    """
    leaves = series.tolist()
    aux_data = {"index": series.index, "name": series.name}
    entry_names = list(series.index.map(_index_element_to_string))
    return leaves, aux_data, entry_names
131145
132- EQUALITY_CHECKERS = {}
133- EQUALITY_CHECKERS [np .ndarray ] = lambda a , b : bool ((a == b ).all ())
134- EQUALITY_CHECKERS [pd .Series ] = lambda a , b : a .equals (b )
135- EQUALITY_CHECKERS [pd .DataFrame ] = lambda a , b : a .equals (b )
136146
147+ def _unflatten_series (aux_data , leaves ):
148+ return pd .Series (leaves , ** aux_data )
137149
138- def tree_equal (tree , other , is_leaf = None , namespace = "" , equality_checkers = None ):
139- equality_checkers = (
140- EQUALITY_CHECKERS
141- if equality_checkers is None
142- else {** EQUALITY_CHECKERS , ** equality_checkers }
143- )
144150
145- first_flat , first_treespec = tree_flatten (
146- tree , is_leaf = is_leaf , namespace = namespace
147- )
148- second_flat , second_treespec = tree_flatten (
149- other , is_leaf = is_leaf , namespace = namespace
150- )
def _flatten_ndarray(arr: np.ndarray):
    """Flatten an array into ``(leaves, aux_data, names)``.

    Returns:
        tuple: The flattened entries as a list, the shape of the array, and the
        element names produced by ``_array_element_names``.

    """
    leaves = arr.flatten().tolist()
    shape = arr.shape
    element_names = _array_element_names(arr)
    return leaves, shape, element_names
151153
152- first_names = leaf_names (tree , is_leaf = is_leaf , namespace = namespace )
153- second_names = leaf_names (tree , is_leaf = is_leaf , namespace = namespace )
154154
155- equal = first_names == second_names and first_treespec == second_treespec
155+ def _unflatten_ndarray (aux_data , leaves ):
156+ return np .array (leaves ).reshape (aux_data )
156157
157- if equal :
158- for first , second in zip (first_flat , second_flat , strict = True ):
159- check_func = equality_checkers .get (type (first ), lambda a , b : a == b )
160- equal = equal and check_func (first , second )
161- if not equal :
162- break
163158
164- return equal
# Register DataFrame, Series and ndarray as pytree nodes in every namespace.
# For DataFrames the namespace name is also bound as ``data_col`` of the
# flatten and unflatten functions.
for namespace in namespaces:
    _registrations = (
        (
            pd.DataFrame,
            partial(_flatten_df, data_col=namespace),
            partial(_unflatten_df, data_col=namespace),
        ),
        (pd.Series, _flatten_series, _unflatten_series),
        (np.ndarray, _flatten_ndarray, _unflatten_ndarray),
    )
    for _node_type, _flatten_func, _unflatten_func in _registrations:
        optree.register_pytree_node(
            _node_type,
            _flatten_func,
            _unflatten_func,
            namespace=namespace,
        )
0 commit comments