@@ -185,3 +185,35 @@ def _unflatten_df_optree(aux_data, leaves):
185185 lambda aux_data , leaves : np .array (leaves ).reshape (aux_data ),
186186 namespace = extended_namespace ,
187187)
188+
189+ EQUALITY_CHECKERS = {}
190+ EQUALITY_CHECKERS [np .ndarray ] = lambda a , b : bool ((a == b ).all ())
191+ EQUALITY_CHECKERS [pd .Series ] = lambda a , b : a .equals (b )
192+ EQUALITY_CHECKERS [pd .DataFrame ] = lambda a , b : a .equals (b )
193+
194+
195+ def tree_equal (tree , other , is_leaf = None , registry = None , equality_checkers = None ):
196+ equality_checkers = (
197+ EQUALITY_CHECKERS
198+ if equality_checkers is None
199+ else {** EQUALITY_CHECKERS , ** equality_checkers }
200+ )
201+
202+ first_flat , first_treespec = tree_flatten (tree , is_leaf = is_leaf , registry = registry )
203+ second_flat , second_treespec = tree_flatten (
204+ other , is_leaf = is_leaf , registry = registry
205+ )
206+
207+ first_names = leaf_names (tree , is_leaf = is_leaf , registry = registry )
208+ second_names = leaf_names (tree , is_leaf = is_leaf , registry = registry )
209+
210+ equal = first_names == second_names and first_treespec == second_treespec
211+
212+ if equal :
213+ for first , second in zip (first_flat , second_flat , strict = True ):
214+ check_func = equality_checkers .get (type (first ), lambda a , b : a == b )
215+ equal = equal and check_func (first , second )
216+ if not equal :
217+ break
218+
219+ return equal
0 commit comments