Skip to content

Commit e3eb382

Browse files
committed
chore: replace tree_equal method with optree impl
1 parent 109623f commit e3eb382

5 files changed

Lines changed: 40 additions & 7 deletions

File tree

src/optimagic/parameters/tree_registry.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,35 @@ def _unflatten_df_optree(aux_data, leaves):
185185
lambda aux_data, leaves: np.array(leaves).reshape(aux_data),
186186
namespace=extended_namespace,
187187
)
188+
189+
EQUALITY_CHECKERS = {}
190+
EQUALITY_CHECKERS[np.ndarray] = lambda a, b: bool((a == b).all())
191+
EQUALITY_CHECKERS[pd.Series] = lambda a, b: a.equals(b)
192+
EQUALITY_CHECKERS[pd.DataFrame] = lambda a, b: a.equals(b)
193+
194+
195+
def tree_equal(tree, other, is_leaf=None, registry=None, equality_checkers=None):
196+
equality_checkers = (
197+
EQUALITY_CHECKERS
198+
if equality_checkers is None
199+
else {**EQUALITY_CHECKERS, **equality_checkers}
200+
)
201+
202+
first_flat, first_treespec = tree_flatten(tree, is_leaf=is_leaf, registry=registry)
203+
second_flat, second_treespec = tree_flatten(
204+
other, is_leaf=is_leaf, registry=registry
205+
)
206+
207+
first_names = leaf_names(tree, is_leaf=is_leaf, registry=registry)
208+
second_names = leaf_names(tree, is_leaf=is_leaf, registry=registry)
209+
210+
equal = first_names == second_names and first_treespec == second_treespec
211+
212+
if equal:
213+
for first, second in zip(first_flat, second_flat, strict=True):
214+
check_func = equality_checkers.get(type(first), lambda a, b: a == b)
215+
equal = equal and check_func(first, second)
216+
if not equal:
217+
break
218+
219+
return equal

tests/estimagic/test_shared.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import pandas as pd
55
import pytest
66
from numpy.testing import assert_array_almost_equal as aaae
7-
from pybaum import tree_equal
87

98
from estimagic.shared_covs import (
109
_to_numpy,
@@ -15,7 +14,7 @@
1514
transform_free_cov_to_cov,
1615
transform_free_values_to_params_tree,
1716
)
18-
from optimagic.parameters.tree_registry import get_registry, leaf_names
17+
from optimagic.parameters.tree_registry import get_registry, leaf_names, tree_equal
1918
from optimagic.utilities import get_rng
2019

2120

tests/optimagic/differentiation/test_compare_derivatives_with_jax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
import numpy as np
88
import pytest
99
from numpy.testing import assert_array_almost_equal as aaae
10-
from pybaum import tree_equal
1110

1211
from optimagic.config import IS_JAX_INSTALLED
1312
from optimagic.differentiation.derivatives import first_derivative, second_derivative
13+
from optimagic.parameters.tree_registry import tree_equal
1414

1515
if not IS_JAX_INSTALLED:
1616
pytestmark = pytest.mark.skip(reason="jax is not installed.")

tests/optimagic/logging/test_logger.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import numpy as np
44
import pandas as pd
55
import pytest
6-
from pybaum import tree_equal
76

87
from optimagic.logging.logger import (
98
LogOptions,
@@ -13,7 +12,11 @@
1312
SQLiteLogReader,
1413
)
1514
from optimagic.optimization.optimize import minimize
16-
from optimagic.parameters.tree_registry import get_registry, tree_just_flatten
15+
from optimagic.parameters.tree_registry import (
16+
get_registry,
17+
tree_equal,
18+
tree_just_flatten,
19+
)
1720
from optimagic.typing import Direction
1821

1922

tests/optimagic/parameters/test_block_trees.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import pandas as pd
33
import pytest
44
from numpy.testing import assert_array_equal
5-
from pybaum import tree_equal
65

76
from optimagic import second_derivative
87
from optimagic.parameters.block_trees import (
@@ -11,7 +10,7 @@
1110
hessian_to_block_tree,
1211
matrix_to_block_tree,
1312
)
14-
from optimagic.parameters.tree_registry import get_registry
13+
from optimagic.parameters.tree_registry import get_registry, tree_equal
1514
from optimagic.parameters.tree_registry import tree_just_flatten as tree_leaves
1615

1716

0 commit comments

Comments
 (0)