Skip to content

Commit 5aa2b08

Browse files
committed
chore: remove get_registry method and use namespace arugment
1 parent e3eb382 commit 5aa2b08

28 files changed

Lines changed: 176 additions & 284 deletions

src/estimagic/bootstrap.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from optimagic.batch_evaluators import joblib_batch_evaluator
1414
from optimagic.parameters.block_trees import matrix_to_block_tree
1515
from optimagic.parameters.tree_registry import (
16-
get_registry,
16+
extended,
1717
leaf_names,
1818
tree_flatten,
1919
tree_just_flatten,
@@ -107,9 +107,8 @@ def bootstrap(
107107
# Process results
108108
# ==================================================================================
109109

110-
registry = get_registry(extended=True)
111110
flat_outcomes = [
112-
tree_just_flatten(_outcome, registry=registry) for _outcome in all_outcomes
111+
tree_just_flatten(_outcome, namespace=extended) for _outcome in all_outcomes
113112
]
114113
internal_outcomes = np.array(flat_outcomes)
115114

@@ -167,11 +166,10 @@ def outcomes(self):
167166
List[Any]: The boostrap outcomes as a list of pytrees.
168167
169168
"""
170-
registry = get_registry(extended=True)
171-
_, treedef = tree_flatten(self._base_outcome, registry=registry)
169+
_, treedef = tree_flatten(self._base_outcome, namespace=extended)
172170

173171
outcomes = [
174-
tree_unflatten(treedef, out, registry=registry)
172+
tree_unflatten(treedef, out, namespace=extended)
175173
for out in self._internal_outcomes
176174
]
177175
return outcomes
@@ -187,10 +185,9 @@ def se(self):
187185
cov = self._internal_cov
188186
se = np.sqrt(np.diagonal(cov))
189187

190-
registry = get_registry(extended=True)
191-
_, treedef = tree_flatten(self._base_outcome, registry=registry)
188+
_, treedef = tree_flatten(self._base_outcome, namespace=extended)
192189

193-
se = tree_unflatten(treedef, se, registry=registry)
190+
se = tree_unflatten(treedef, se, namespace=extended)
194191
return se
195192

196193
def cov(self, return_type="pytree"):
@@ -211,8 +208,7 @@ def cov(self, return_type="pytree"):
211208
cov = self._internal_cov
212209

213210
if return_type == "dataframe":
214-
registry = get_registry(extended=True)
215-
names = np.array(leaf_names(self._base_outcome, registry=registry))
211+
names = np.array(leaf_names(self._base_outcome, namespace=extended))
216212
cov = pd.DataFrame(cov, columns=names, index=names)
217213
elif return_type == "pytree":
218214
cov = matrix_to_block_tree(cov, self._base_outcome, self._base_outcome)
@@ -239,15 +235,16 @@ def ci(self, ci_method="percentile", ci_level=0.95):
239235
bounds of confidence intervals.
240236
241237
"""
242-
registry = get_registry(extended=True)
243-
base_outcome_flat, treedef = tree_flatten(self._base_outcome, registry=registry)
238+
base_outcome_flat, treedef = tree_flatten(
239+
self._base_outcome, namespace=extended
240+
)
244241

245242
lower_flat, upper_flat = calculate_ci(
246243
base_outcome_flat, self._internal_outcomes, ci_method, ci_level
247244
)
248245

249-
lower = tree_unflatten(treedef, lower_flat, registry=registry)
250-
upper = tree_unflatten(treedef, upper_flat, registry=registry)
246+
lower = tree_unflatten(treedef, lower_flat, namespace=extended)
247+
upper = tree_unflatten(treedef, upper_flat, namespace=extended)
251248
return lower, upper
252249

253250
def p_values(self):
@@ -276,8 +273,7 @@ def summary(self, ci_method="percentile", ci_level=0.95):
276273
Soon this will be a pytree.
277274
278275
"""
279-
registry = get_registry(extended=True)
280-
names = leaf_names(self.base_outcome, registry=registry)
276+
names = leaf_names(self.base_outcome, namespace=extended)
281277
summary_data = _calulcate_summary_data_bootstrap(
282278
self, ci_method=ci_method, ci_level=ci_level
283279
)

src/estimagic/estimate_msm.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
from optimagic.parameters.conversion import Converter, get_converter
5252
from optimagic.parameters.space_conversion import InternalParams
5353
from optimagic.parameters.tree_registry import (
54-
get_registry,
54+
extended,
5555
leaf_names,
5656
tree_just_flatten,
5757
)
@@ -321,8 +321,7 @@ def func(x):
321321
sim_mom = simulate_moments(params, **simulate_moments_kwargs)
322322
if isinstance(sim_mom, dict) and "simulated_moments" in sim_mom:
323323
sim_mom = sim_mom["simulated_moments"]
324-
registry = get_registry(extended=True)
325-
out = np.array(tree_just_flatten(sim_mom, registry=registry))
324+
out = np.array(tree_just_flatten(sim_mom, namespace=extended))
326325
return out
327326

328327
int_jac = first_derivative(
@@ -421,8 +420,7 @@ def get_msm_optimization_functions(
421420

422421
chol_weights = np.linalg.cholesky(flat_weights)
423422

424-
registry = get_registry(extended=True)
425-
flat_emp_mom = tree_just_flatten(empirical_moments, registry=registry)
423+
flat_emp_mom = tree_just_flatten(empirical_moments, namespace=extended)
426424

427425
_simulate_moments = _partial_kwargs(simulate_moments, simulate_moments_kwargs)
428426
_jacobian = _partial_kwargs(jacobian, jacobian_kwargs)
@@ -433,7 +431,7 @@ def get_msm_optimization_functions(
433431
simulate_moments=_simulate_moments,
434432
flat_empirical_moments=flat_emp_mom,
435433
chol_weights=chol_weights,
436-
registry=registry,
434+
namespace=extended,
437435
)
438436
)
439437

@@ -448,7 +446,7 @@ def get_msm_optimization_functions(
448446

449447

450448
def _msm_criterion(
451-
params, simulate_moments, flat_empirical_moments, chol_weights, registry
449+
params, simulate_moments, flat_empirical_moments, chol_weights, namespace
452450
):
453451
"""Calculate msm criterion given parameters and building blocks."""
454452
simulated = simulate_moments(params)
@@ -457,7 +455,7 @@ def _msm_criterion(
457455
if isinstance(simulated, np.ndarray) and simulated.ndim == 1:
458456
simulated_flat = simulated
459457
else:
460-
simulated_flat = np.array(tree_just_flatten(simulated, registry=registry))
458+
simulated_flat = np.array(tree_just_flatten(simulated, namespace=namespace))
461459

462460
deviations = simulated_flat - flat_empirical_moments
463461
residuals = deviations @ chol_weights
@@ -978,9 +976,8 @@ def sensitivity(
978976
inner_tree=self._empirical_moments,
979977
)
980978
elif return_type == "dataframe":
981-
registry = get_registry(extended=True)
982979
row_names = self._internal_estimates.names
983-
col_names = leaf_names(self._empirical_moments, registry=registry)
980+
col_names = leaf_names(self._empirical_moments, namespace=extended)
984981
out = pd.DataFrame(
985982
data=raw,
986983
index=row_names,

src/estimagic/msm_weighting.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from estimagic.bootstrap import bootstrap
88
from optimagic.parameters.block_trees import block_tree_to_matrix, matrix_to_block_tree
9-
from optimagic.parameters.tree_registry import get_registry, tree_just_flatten
9+
from optimagic.parameters.tree_registry import extended, tree_just_flatten
1010
from optimagic.utilities import robust_inverse
1111

1212

@@ -50,13 +50,11 @@ def get_moments_cov(
5050

5151
first_eval = calculate_moments(data, **moment_kwargs)
5252

53-
registry = get_registry(extended=True)
54-
5553
@functools.wraps(calculate_moments)
5654
def func(data, **kwargs):
5755
raw = calculate_moments(data, **kwargs)
5856
out = pd.Series(
59-
tree_just_flatten(raw, registry=registry)
57+
tree_just_flatten(raw, namespace=extended)
6058
) # xxxx won't be necessary soon!
6159
return out
6260

src/estimagic/shared_covs.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from optimagic.parameters.block_trees import matrix_to_block_tree
88
from optimagic.parameters.tree_registry import (
9-
get_registry,
9+
extended,
1010
tree_just_flatten,
1111
tree_unflatten,
1212
)
@@ -149,9 +149,8 @@ def calculate_estimation_summary(
149149
# Flatten summary and construct data frame for flat estimates
150150
# ==================================================================================
151151

152-
registry = get_registry(extended=True)
153152
flat_data = {
154-
key: tree_just_flatten(val, registry=registry)
153+
key: tree_just_flatten(val, namespace=extended)
155154
for key, val in summary_data.items()
156155
}
157156

@@ -170,7 +169,7 @@ def calculate_estimation_summary(
170169
# ==================================================================================
171170

172171
# create tree with values corresponding to indices of df
173-
indices = tree_unflatten(summary_data["value"], names, registry=registry)
172+
indices = tree_unflatten(summary_data["value"], names, namespace=extended)
174173

175174
estimates_flat = tree_just_flatten(summary_data["value"])
176175
indices_flat = tree_just_flatten(indices)
@@ -319,8 +318,7 @@ def calculate_free_estimates(estimates, internal_estimates):
319318
mask = internal_estimates.free_mask
320319
names = internal_estimates.names
321320

322-
registry = get_registry(extended=True)
323-
external_flat = np.array(tree_just_flatten(estimates, registry=registry))
321+
external_flat = np.array(tree_just_flatten(estimates, namespace=extended))
324322

325323
free_estimates = FreeParams(
326324
values=external_flat[mask],
@@ -354,8 +352,7 @@ def transform_free_values_to_params_tree(values, free_params, params):
354352
mask = free_params.free_mask
355353
flat = np.full(len(mask), np.nan)
356354
flat[np.ix_(mask)] = values
357-
registry = get_registry(extended=True)
358-
pytree = tree_unflatten(params, flat, registry=registry)
355+
pytree = tree_unflatten(params, flat, namespace=extended)
359356
return pytree
360357

361358

src/optimagic/benchmarking/run_benchmark.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from optimagic import batch_evaluators
1414
from optimagic.algorithms import AVAILABLE_ALGORITHMS
1515
from optimagic.optimization.optimize import minimize
16-
from optimagic.parameters.tree_registry import get_registry, tree_just_flatten
16+
from optimagic.parameters.tree_registry import extended, tree_just_flatten
1717

1818

1919
def run_benchmark(
@@ -179,7 +179,6 @@ def _process_one_result(optimize_result, problem):
179179
dict: Processed result.
180180
181181
"""
182-
_registry = get_registry(extended=True)
183182
_criterion = problem["noise_free_fun"]
184183
_start_x = problem["inputs"]["params"]
185184
_start_crit_value = _criterion(_start_x)
@@ -190,15 +189,15 @@ def _process_one_result(optimize_result, problem):
190189

191190
# This will happen if the optimization raised an error
192191
if isinstance(optimize_result, str):
193-
params_history_flat = [tree_just_flatten(_start_x, registry=_registry)]
192+
params_history_flat = [tree_just_flatten(_start_x, namespace=extended)]
194193
criterion_history = [_start_crit_value]
195194
time_history = [np.inf]
196195
batches_history = [0]
197196
else:
198197
history = optimize_result.history
199198
params_history = history.params
200199
params_history_flat = [
201-
tree_just_flatten(p, registry=_registry) for p in params_history
200+
tree_just_flatten(p, namespace=extended) for p in params_history
202201
]
203202
if _is_noisy:
204203
criterion_history = np.array([_criterion(p) for p in params_history])

0 commit comments

Comments
 (0)