Skip to content

Commit b9e32d7

Browse files
committed
chore: add type hints
1 parent 4f42b27 commit b9e32d7

3 files changed

Lines changed: 48 additions & 17 deletions

File tree

src/optimagic/differentiation/derivatives.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -220,8 +220,8 @@ def first_derivative(
220220
is_fast_path = _is_1d_array(params)
221221

222222
if not is_fast_path:
223-
x, params_treedef = tree_flatten(params, namespace=VALUE_NAMESPACE)
224-
x = np.array(x, dtype=np.float64)
223+
params_leaves, params_treedef = tree_flatten(params, namespace=VALUE_NAMESPACE)
224+
x = np.array(params_leaves, dtype=np.float64)
225225

226226
if scaling_factor is not None and not np.isscalar(scaling_factor):
227227
scaling_factor = np.array(
@@ -272,7 +272,7 @@ def first_derivative(
272272
step_size = cast(NDArray[np.float64], step_size)
273273

274274
# generate parameter vectors at which func has to be evaluated as numpy arrays
275-
evaluation_points = []
275+
evaluation_points: list[float | np.ndarray] = []
276276
for step_arr in step_size:
277277
for i, j in product(range(n_steps), range(len(x))):
278278
if np.isnan(step_arr[i, j]):
@@ -534,8 +534,8 @@ def second_derivative(
534534
is_fast_path = _is_1d_array(params)
535535

536536
if not is_fast_path:
537-
x, params_treedef = tree_flatten(params, namespace=VALUE_NAMESPACE)
538-
x = np.array(x, dtype=np.float64)
537+
params_leaves, params_treedef = tree_flatten(params, namespace=VALUE_NAMESPACE)
538+
x = np.array(params_leaves, dtype=np.float64)
539539

540540
if scaling_factor is not None and not np.isscalar(scaling_factor):
541541
scaling_factor = np.array(

src/optimagic/optimization/fun_value.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,8 @@ def _get_flat_value(value: PyTree) -> NDArray[np.float64]:
123123
elif isinstance(value, np.ndarray):
124124
flat = value.flatten()
125125
else:
126-
flat = tree_leaves(value, namespace=VALUE_NAMESPACE)
126+
value_leaves = tree_leaves(value, namespace=VALUE_NAMESPACE)
127+
flat = np.asarray(value_leaves, dtype=np.float64)
127128

128129
flat_arr = np.asarray(flat, dtype=np.float64)
129130
return flat_arr

src/optimagic/parameters/tree_registry.py

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
import warnings
44
from functools import partial
55
from itertools import product
6+
from typing import Any, Callable, Iterable
67

78
import numpy as np
89
import optree
910
import pandas as pd
1011
from optree.pytree import PyTreeSpec
1112

12-
from optimagic.typing import DEFAULT_NAMESPACE, OPTREE_NAMESPACES
13+
from optimagic.typing import DEFAULT_NAMESPACE, OPTREE_NAMESPACES, PyTree
1314

1415
try:
1516
import jax.numpy as jnp # type: ignore[import-not-found]
@@ -22,23 +23,35 @@
2223
_are_namespaces_registered = False
2324

2425

25-
def tree_flatten(tree, is_leaf=None, namespace=DEFAULT_NAMESPACE):
26+
def tree_flatten(
27+
tree: PyTree,
28+
is_leaf: Callable[[PyTree], bool] | None = None,
29+
namespace: str = DEFAULT_NAMESPACE,
30+
) -> tuple[list, PyTreeSpec]:
2631
"""Flatten a pytree."""
2732
_register_namespaces()
2833
_check_namespace(namespace)
2934
with optree.dict_insertion_ordered(True, namespace=namespace):
3035
return optree.tree_flatten(tree, is_leaf=is_leaf, namespace=namespace)
3136

3237

33-
def tree_leaves(tree, is_leaf=None, namespace=DEFAULT_NAMESPACE):
38+
def tree_leaves(
39+
tree: PyTree,
40+
is_leaf: Callable[[PyTree], bool] | None = None,
41+
namespace: str = DEFAULT_NAMESPACE,
42+
) -> list:
3443
"""Get the leaves of a pytree."""
3544
_register_namespaces()
3645
_check_namespace(namespace)
3746
with optree.dict_insertion_ordered(True, namespace=namespace):
3847
return optree.tree_leaves(tree, is_leaf=is_leaf, namespace=namespace)
3948

4049

41-
def tree_unflatten(treedef, leaves, namespace=DEFAULT_NAMESPACE):
50+
def tree_unflatten(
51+
treedef: PyTree | PyTreeSpec,
52+
leaves: Iterable,
53+
namespace: str = DEFAULT_NAMESPACE,
54+
) -> PyTree:
4255
"""Reconstruct a pytree from the tree definition and the leaves."""
4356
_register_namespaces()
4457

@@ -52,7 +65,12 @@ def tree_unflatten(treedef, leaves, namespace=DEFAULT_NAMESPACE):
5265
return optree.tree_unflatten(treedef, leaves)
5366

5467

55-
def tree_map(func, tree, is_leaf=None, namespace=DEFAULT_NAMESPACE):
68+
def tree_map(
69+
func: Callable[[PyTree], PyTree],
70+
tree: PyTree,
71+
is_leaf: Callable[[PyTree], bool] | None = None,
72+
namespace: str = DEFAULT_NAMESPACE,
73+
) -> PyTree:
5674
"""Map an input function over pytree args to produce a new pytree."""
5775
_register_namespaces()
5876
_check_namespace(namespace)
@@ -62,7 +80,12 @@ def tree_map(func, tree, is_leaf=None, namespace=DEFAULT_NAMESPACE):
6280
return optree.tree_map(func, tree, is_leaf=is_leaf, namespace=namespace)
6381

6482

65-
def leaf_names(tree, is_leaf=None, namespace=DEFAULT_NAMESPACE, separator="_"):
83+
def leaf_names(
84+
tree: PyTree,
85+
is_leaf: Callable[[PyTree], bool] | None = None,
86+
namespace: str = DEFAULT_NAMESPACE,
87+
separator: str = "_",
88+
) -> list[str]:
6689
"""Get the path names for tree leaves."""
6790
_register_namespaces()
6891
_check_namespace(namespace)
@@ -75,8 +98,12 @@ def leaf_names(tree, is_leaf=None, namespace=DEFAULT_NAMESPACE, separator="_"):
7598

7699

77100
def tree_equal(
78-
tree, other, is_leaf=None, namespace=DEFAULT_NAMESPACE, equality_checkers=None
79-
):
101+
tree: PyTree,
102+
other: PyTree,
103+
is_leaf: Callable[[PyTree], bool] | None = None,
104+
namespace: str = DEFAULT_NAMESPACE,
105+
equality_checkers: dict[str, Callable[[Any, Any], bool]] | None = None,
106+
) -> bool:
80107
"""Check the equality between two trees."""
81108
equality_checkers = (
82109
_get_equality_checkers()
@@ -238,7 +265,8 @@ def _unflatten_jax_array(aux_data, leaves):
238265
return jnp.array(leaves).reshape(aux_data)
239266

240267

241-
def _get_df_names(df):
268+
def _get_df_names(df: pd.DataFrame) -> list[str]:
269+
"""Get string names for dataframe leaf paths."""
242270
index_strings = list(df.index.map(_index_element_to_string))
243271
if "value" in df:
244272
out = index_strings
@@ -248,7 +276,8 @@ def _get_df_names(df):
248276
return out
249277

250278

251-
def _index_element_to_string(element):
279+
def _index_element_to_string(element: Any) -> str:
280+
"""Convert an index element to its string representation."""
252281
if isinstance(element, (tuple, list)):
253282
as_strings = [str(entry) for entry in element]
254283
res_string = "_".join(as_strings)
@@ -258,7 +287,8 @@ def _index_element_to_string(element):
258287
return res_string
259288

260289

261-
def _array_element_names(arr):
290+
def _array_element_names(arr: np.ndarray) -> list[str]:
291+
"""Get string names for array like element leaf paths."""
262292
dim_names = [map(str, range(n)) for n in arr.shape]
263293
names = list(map("_".join, product(*dim_names)))
264294
return names

0 commit comments

Comments
 (0)