Skip to content

Commit 38db493

Browse files
committed
fix: use jax installation check from config file
1 parent 19b4db3 commit 38db493

1 file changed

Lines changed: 10 additions & 19 deletions

File tree

src/optimagic/parameters/tree_registry.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,13 @@
1010
import pandas as pd
1111
from optree.pytree import PyTreeSpec
1212

13+
from optimagic.config import IS_JAX_INSTALLED
1314
from optimagic.typing import DEFAULT_NAMESPACE, OPTREE_NAMESPACES, PyTree
1415

15-
try:
16+
if IS_JAX_INSTALLED:
1617
import jax.numpy as jnp # type: ignore[import-not-found]
1718
import jaxlib # type: ignore[import-not-found]
1819

19-
_has_jax = True
20-
except ImportError:
21-
_has_jax = False
2220

2321
_are_namespaces_registered = False
2422

@@ -146,7 +144,7 @@ def _get_equality_checkers():
146144
equality_checkers[pd.Series.__name__] = lambda a, b: a.equals(b)
147145
equality_checkers[pd.DataFrame.__name__] = lambda a, b: a.equals(b)
148146

149-
if _has_jax:
147+
if IS_JAX_INSTALLED:
150148
equality_checkers[jnp.ndarray.__name__] = lambda a, b: bool((a == b).all())
151149

152150
return equality_checkers
@@ -193,11 +191,15 @@ def _register_namespaces() -> None:
193191
namespace=namespace,
194192
)
195193

196-
if _has_jax:
194+
if IS_JAX_INSTALLED:
197195
optree.register_pytree_node(
198196
jaxlib._jax.ArrayImpl,
199-
_flatten_jax_array,
200-
_unflatten_jax_array,
197+
lambda arr: (
198+
arr.flatten().tolist(), # type: ignore[attr-defined]
199+
arr.shape, # type: ignore[attr-defined]
200+
_array_element_names(arr), # type: ignore[arg-type]
201+
),
202+
lambda aux_data, leaves: jnp.array(leaves).reshape(aux_data),
201203
namespace=namespace,
202204
)
203205

@@ -254,17 +256,6 @@ def _unflatten_ndarray(aux_data, leaves):
254256
return np.array(leaves).reshape(aux_data)
255257

256258

257-
if _has_jax:
258-
259-
def _flatten_jax_array(arr):
260-
"""Flatten a jax array."""
261-
return arr.flatten().tolist(), arr.shape, _array_element_names(arr)
262-
263-
def _unflatten_jax_array(aux_data, leaves):
264-
"""Unflatten a jax array."""
265-
return jnp.array(leaves).reshape(aux_data)
266-
267-
268259
def _get_df_names(df: pd.DataFrame) -> list[str]:
269260
"""Get string names for dataframe leaf paths."""
270261
index_strings = list(df.index.map(_index_element_to_string))

0 commit comments

Comments
 (0)