Skip to content

Commit 1415d54

Browse files
committed
chore: add tests for tree methods and rearrange tree-registry methods
1 parent ca240b5 commit 1415d54

30 files changed

Lines changed: 370 additions & 186 deletions

pixi.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/estimagic/bootstrap.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
tree_just_flatten,
1919
tree_unflatten,
2020
)
21-
from optimagic.typing import value_namespace
21+
from optimagic.typing import VALUE_NAMESPACE
2222
from optimagic.utilities import get_rng
2323

2424

@@ -108,7 +108,7 @@ def bootstrap(
108108
# ==================================================================================
109109

110110
flat_outcomes = [
111-
tree_just_flatten(_outcome, namespace=value_namespace)
111+
tree_just_flatten(_outcome, namespace=VALUE_NAMESPACE)
112112
for _outcome in all_outcomes
113113
]
114114
internal_outcomes = np.array(flat_outcomes)
@@ -167,10 +167,10 @@ def outcomes(self):
167167
List[Any]: The boostrap outcomes as a list of pytrees.
168168
169169
"""
170-
_, treedef = tree_flatten(self._base_outcome, namespace=value_namespace)
170+
_, treedef = tree_flatten(self._base_outcome, namespace=VALUE_NAMESPACE)
171171

172172
outcomes = [
173-
tree_unflatten(treedef, out, namespace=value_namespace)
173+
tree_unflatten(treedef, out, namespace=VALUE_NAMESPACE)
174174
for out in self._internal_outcomes
175175
]
176176
return outcomes
@@ -186,9 +186,9 @@ def se(self):
186186
cov = self._internal_cov
187187
se = np.sqrt(np.diagonal(cov))
188188

189-
_, treedef = tree_flatten(self._base_outcome, namespace=value_namespace)
189+
_, treedef = tree_flatten(self._base_outcome, namespace=VALUE_NAMESPACE)
190190

191-
se = tree_unflatten(treedef, se, namespace=value_namespace)
191+
se = tree_unflatten(treedef, se, namespace=VALUE_NAMESPACE)
192192
return se
193193

194194
def cov(self, return_type="pytree"):
@@ -209,7 +209,7 @@ def cov(self, return_type="pytree"):
209209
cov = self._internal_cov
210210

211211
if return_type == "dataframe":
212-
names = np.array(leaf_names(self._base_outcome, namespace=value_namespace))
212+
names = np.array(leaf_names(self._base_outcome, namespace=VALUE_NAMESPACE))
213213
cov = pd.DataFrame(cov, columns=names, index=names)
214214
elif return_type == "pytree":
215215
cov = matrix_to_block_tree(cov, self._base_outcome, self._base_outcome)
@@ -237,15 +237,15 @@ def ci(self, ci_method="percentile", ci_level=0.95):
237237
238238
"""
239239
base_outcome_flat, treedef = tree_flatten(
240-
self._base_outcome, namespace=value_namespace
240+
self._base_outcome, namespace=VALUE_NAMESPACE
241241
)
242242

243243
lower_flat, upper_flat = calculate_ci(
244244
base_outcome_flat, self._internal_outcomes, ci_method, ci_level
245245
)
246246

247-
lower = tree_unflatten(treedef, lower_flat, namespace=value_namespace)
248-
upper = tree_unflatten(treedef, upper_flat, namespace=value_namespace)
247+
lower = tree_unflatten(treedef, lower_flat, namespace=VALUE_NAMESPACE)
248+
upper = tree_unflatten(treedef, upper_flat, namespace=VALUE_NAMESPACE)
249249
return lower, upper
250250

251251
def p_values(self):
@@ -274,7 +274,7 @@ def summary(self, ci_method="percentile", ci_level=0.95):
274274
Soon this will be a pytree.
275275
276276
"""
277-
names = leaf_names(self.base_outcome, namespace=value_namespace)
277+
names = leaf_names(self.base_outcome, namespace=VALUE_NAMESPACE)
278278
summary_data = _calulcate_summary_data_bootstrap(
279279
self, ci_method=ci_method, ci_level=ci_level
280280
)

src/estimagic/estimate_msm.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
from optimagic.shared.check_option_dicts import (
5858
check_optimization_options,
5959
)
60-
from optimagic.typing import value_namespace
60+
from optimagic.typing import VALUE_NAMESPACE
6161
from optimagic.utilities import get_rng, to_pickle
6262

6363

@@ -321,7 +321,7 @@ def func(x):
321321
sim_mom = simulate_moments(params, **simulate_moments_kwargs)
322322
if isinstance(sim_mom, dict) and "simulated_moments" in sim_mom:
323323
sim_mom = sim_mom["simulated_moments"]
324-
out = np.array(tree_just_flatten(sim_mom, namespace=value_namespace))
324+
out = np.array(tree_just_flatten(sim_mom, namespace=VALUE_NAMESPACE))
325325
return out
326326

327327
int_jac = first_derivative(
@@ -420,7 +420,7 @@ def get_msm_optimization_functions(
420420

421421
chol_weights = np.linalg.cholesky(flat_weights)
422422

423-
flat_emp_mom = tree_just_flatten(empirical_moments, namespace=value_namespace)
423+
flat_emp_mom = tree_just_flatten(empirical_moments, namespace=VALUE_NAMESPACE)
424424

425425
_simulate_moments = _partial_kwargs(simulate_moments, simulate_moments_kwargs)
426426
_jacobian = _partial_kwargs(jacobian, jacobian_kwargs)
@@ -431,7 +431,7 @@ def get_msm_optimization_functions(
431431
simulate_moments=_simulate_moments,
432432
flat_empirical_moments=flat_emp_mom,
433433
chol_weights=chol_weights,
434-
namespace=value_namespace,
434+
namespace=VALUE_NAMESPACE,
435435
)
436436
)
437437

@@ -977,7 +977,7 @@ def sensitivity(
977977
)
978978
elif return_type == "dataframe":
979979
row_names = self._internal_estimates.names
980-
col_names = leaf_names(self._empirical_moments, namespace=value_namespace)
980+
col_names = leaf_names(self._empirical_moments, namespace=VALUE_NAMESPACE)
981981
out = pd.DataFrame(
982982
data=raw,
983983
index=row_names,

src/estimagic/msm_weighting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from estimagic.bootstrap import bootstrap
88
from optimagic.parameters.block_trees import block_tree_to_matrix, matrix_to_block_tree
99
from optimagic.parameters.tree_registry import tree_just_flatten
10-
from optimagic.typing import value_namespace
10+
from optimagic.typing import VALUE_NAMESPACE
1111
from optimagic.utilities import robust_inverse
1212

1313

@@ -55,7 +55,7 @@ def get_moments_cov(
5555
def func(data, **kwargs):
5656
raw = calculate_moments(data, **kwargs)
5757
out = pd.Series(
58-
tree_just_flatten(raw, namespace=value_namespace)
58+
tree_just_flatten(raw, namespace=VALUE_NAMESPACE)
5959
) # xxxx won't be necessary soon!
6060
return out
6161

src/estimagic/shared_covs.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
tree_just_flatten,
1010
tree_unflatten,
1111
)
12-
from optimagic.typing import value_namespace
12+
from optimagic.typing import VALUE_NAMESPACE
1313

1414

1515
def transform_covariance(
@@ -150,7 +150,7 @@ def calculate_estimation_summary(
150150
# ==================================================================================
151151

152152
flat_data = {
153-
key: tree_just_flatten(val, namespace=value_namespace)
153+
key: tree_just_flatten(val, namespace=VALUE_NAMESPACE)
154154
for key, val in summary_data.items()
155155
}
156156

@@ -169,7 +169,7 @@ def calculate_estimation_summary(
169169
# ==================================================================================
170170

171171
# create tree with values corresponding to indices of df
172-
indices = tree_unflatten(summary_data["value"], names, namespace=value_namespace)
172+
indices = tree_unflatten(summary_data["value"], names, namespace=VALUE_NAMESPACE)
173173

174174
estimates_flat = tree_just_flatten(summary_data["value"])
175175
indices_flat = tree_just_flatten(indices)
@@ -318,7 +318,7 @@ def calculate_free_estimates(estimates, internal_estimates):
318318
mask = internal_estimates.free_mask
319319
names = internal_estimates.names
320320

321-
external_flat = np.array(tree_just_flatten(estimates, namespace=value_namespace))
321+
external_flat = np.array(tree_just_flatten(estimates, namespace=VALUE_NAMESPACE))
322322

323323
free_estimates = FreeParams(
324324
values=external_flat[mask],
@@ -352,7 +352,7 @@ def transform_free_values_to_params_tree(values, free_params, params):
352352
mask = free_params.free_mask
353353
flat = np.full(len(mask), np.nan)
354354
flat[np.ix_(mask)] = values
355-
pytree = tree_unflatten(params, flat, namespace=value_namespace)
355+
pytree = tree_unflatten(params, flat, namespace=VALUE_NAMESPACE)
356356
return pytree
357357

358358

src/optimagic/benchmarking/run_benchmark.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from optimagic.algorithms import AVAILABLE_ALGORITHMS
1515
from optimagic.optimization.optimize import minimize
1616
from optimagic.parameters.tree_registry import tree_just_flatten
17-
from optimagic.typing import value_namespace
17+
from optimagic.typing import VALUE_NAMESPACE
1818

1919

2020
def run_benchmark(
@@ -190,15 +190,15 @@ def _process_one_result(optimize_result, problem):
190190

191191
# This will happen if the optimization raised an error
192192
if isinstance(optimize_result, str):
193-
params_history_flat = [tree_just_flatten(_start_x, namespace=value_namespace)]
193+
params_history_flat = [tree_just_flatten(_start_x, namespace=VALUE_NAMESPACE)]
194194
criterion_history = [_start_crit_value]
195195
time_history = [np.inf]
196196
batches_history = [0]
197197
else:
198198
history = optimize_result.history
199199
params_history = history.params
200200
params_history_flat = [
201-
tree_just_flatten(p, namespace=value_namespace) for p in params_history
201+
tree_just_flatten(p, namespace=VALUE_NAMESPACE) for p in params_history
202202
]
203203
if _is_noisy:
204204
criterion_history = np.array([_criterion(p) for p in params_history])

src/optimagic/differentiation/derivatives.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
tree_unflatten,
2727
)
2828
from optimagic.parameters.tree_registry import tree_just_flatten as tree_leaves
29-
from optimagic.typing import BatchEvaluatorLiteral, PyTree, value_namespace
29+
from optimagic.typing import VALUE_NAMESPACE, BatchEvaluatorLiteral, PyTree
3030

3131

3232
@dataclass(frozen=True)
@@ -221,22 +221,22 @@ def first_derivative(
221221
is_fast_path = _is_1d_array(params)
222222

223223
if not is_fast_path:
224-
x, params_treedef = tree_flatten(params, namespace=value_namespace)
224+
x, params_treedef = tree_flatten(params, namespace=VALUE_NAMESPACE)
225225
x = np.array(x, dtype=np.float64)
226226

227227
if scaling_factor is not None and not np.isscalar(scaling_factor):
228228
scaling_factor = np.array(
229-
tree_just_flatten(scaling_factor, namespace=value_namespace)
229+
tree_just_flatten(scaling_factor, namespace=VALUE_NAMESPACE)
230230
)
231231

232232
if min_steps is not None and not np.isscalar(min_steps):
233233
min_steps = np.array(
234-
tree_just_flatten(min_steps, namespace=value_namespace)
234+
tree_just_flatten(min_steps, namespace=VALUE_NAMESPACE)
235235
)
236236

237237
if step_size is not None and not np.isscalar(step_size):
238238
step_size = np.array(
239-
tree_just_flatten(step_size, namespace=value_namespace)
239+
tree_just_flatten(step_size, namespace=VALUE_NAMESPACE)
240240
)
241241
else:
242242
x = params.astype(np.float64)
@@ -291,7 +291,7 @@ def first_derivative(
291291
if not is_fast_path:
292292
evaluation_points = [
293293
# entries are either a numpy.ndarray or np.nan
294-
_unflatten_if_not_nan(p, params_treedef, value_namespace)
294+
_unflatten_if_not_nan(p, params_treedef, VALUE_NAMESPACE)
295295
for p in evaluation_points
296296
]
297297

@@ -330,14 +330,14 @@ def first_derivative(
330330
elif vector_out:
331331
f0 = f0_tree.astype(float)
332332
else:
333-
f0 = tree_leaves(f0_tree, namespace=value_namespace)
333+
f0 = tree_leaves(f0_tree, namespace=VALUE_NAMESPACE)
334334
f0 = np.array(f0, dtype=np.float64)
335335

336336
# convert the raw evaluations to numpy arrays
337337
raw_evals_arr = _convert_evals_to_numpy(
338338
raw_evals=raw_evals,
339339
unpacker=unpacker,
340-
namespace=value_namespace,
340+
namespace=VALUE_NAMESPACE,
341341
is_scalar_out=scalar_out,
342342
is_vector_out=vector_out,
343343
)
@@ -539,22 +539,22 @@ def second_derivative(
539539
is_fast_path = _is_1d_array(params)
540540

541541
if not is_fast_path:
542-
x, params_treedef = tree_flatten(params, namespace=value_namespace)
542+
x, params_treedef = tree_flatten(params, namespace=VALUE_NAMESPACE)
543543
x = np.array(x, dtype=np.float64)
544544

545545
if scaling_factor is not None and not np.isscalar(scaling_factor):
546546
scaling_factor = np.array(
547-
tree_just_flatten(scaling_factor, namespace=value_namespace)
547+
tree_just_flatten(scaling_factor, namespace=VALUE_NAMESPACE)
548548
)
549549

550550
if min_steps is not None and not np.isscalar(min_steps):
551551
min_steps = np.array(
552-
tree_just_flatten(min_steps, namespace=value_namespace)
552+
tree_just_flatten(min_steps, namespace=VALUE_NAMESPACE)
553553
)
554554

555555
if step_size is not None and not np.isscalar(step_size):
556556
step_size = np.array(
557-
tree_just_flatten(step_size, namespace=value_namespace)
557+
tree_just_flatten(step_size, namespace=VALUE_NAMESPACE)
558558
)
559559
else:
560560
x = params.astype(np.float64)
@@ -631,7 +631,7 @@ def second_derivative(
631631
evaluation_points = {
632632
# entries are either a numpy.ndarray or np.nan, we unflatten only
633633
step_type: [
634-
_unflatten_if_not_nan(p, params_treedef, value_namespace)
634+
_unflatten_if_not_nan(p, params_treedef, VALUE_NAMESPACE)
635635
for p in points
636636
]
637637
for step_type, points in evaluation_points.items()
@@ -671,13 +671,13 @@ def second_derivative(
671671
func_value = f0
672672

673673
f0_tree = unpacker(f0)
674-
f0 = tree_leaves(f0_tree, namespace=value_namespace)
674+
f0 = tree_leaves(f0_tree, namespace=VALUE_NAMESPACE)
675675
f0 = np.array(f0, dtype=np.float64)
676676

677677
# convert the raw evaluations to numpy arrays
678678
raw_evals = {
679679
step_type: _convert_evals_to_numpy(
680-
raw_evals=evals, unpacker=unpacker, namespace=value_namespace
680+
raw_evals=evals, unpacker=unpacker, namespace=VALUE_NAMESPACE
681681
)
682682
for step_type, evals in raw_evals.items()
683683
}

src/optimagic/examples/criterion_functions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
tree_just_flatten,
2121
tree_unflatten,
2222
)
23-
from optimagic.typing import PyTree, value_namespace
23+
from optimagic.typing import VALUE_NAMESPACE, PyTree
2424

2525

2626
@mark.scalar
@@ -215,11 +215,11 @@ def _get_x(params: PyTree) -> NDArray[np.float64]:
215215
x = params.astype(float)
216216
else:
217217
x = np.array(
218-
tree_just_flatten(params, namespace=value_namespace), dtype=np.float64
218+
tree_just_flatten(params, namespace=VALUE_NAMESPACE), dtype=np.float64
219219
)
220220
return x
221221

222222

223223
def _unflatten_gradient(flat: NDArray[np.float64], params: PyTree) -> PyTree:
224-
out = tree_unflatten(params, flat.tolist(), namespace=value_namespace)
224+
out = tree_unflatten(params, flat.tolist(), namespace=VALUE_NAMESPACE)
225225
return out

src/optimagic/optimization/fun_value.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from optimagic.exceptions import InvalidFunctionError
1010
from optimagic.parameters.tree_registry import tree_just_flatten
11-
from optimagic.typing import AggregationLevel, PyTree, Scalar, value_namespace
11+
from optimagic.typing import VALUE_NAMESPACE, AggregationLevel, PyTree, Scalar
1212
from optimagic.utilities import isscalar
1313

1414

@@ -123,7 +123,7 @@ def _get_flat_value(value: PyTree) -> NDArray[np.float64]:
123123
elif isinstance(value, np.ndarray):
124124
flat = value.flatten()
125125
else:
126-
flat = tree_just_flatten(value, namespace=value_namespace)
126+
flat = tree_just_flatten(value, namespace=VALUE_NAMESPACE)
127127

128128
flat_arr = np.asarray(flat, dtype=np.float64)
129129
return flat_arr

0 commit comments

Comments
 (0)