|
10 | 10 | import pandas as pd |
11 | 11 | from optree.pytree import PyTreeSpec |
12 | 12 |
|
| 13 | +from optimagic.config import IS_JAX_INSTALLED |
13 | 14 | from optimagic.typing import DEFAULT_NAMESPACE, OPTREE_NAMESPACES, PyTree |
14 | 15 |
|
15 | | -try: |
| 16 | +if IS_JAX_INSTALLED: |
16 | 17 | import jax.numpy as jnp # type: ignore[import-not-found] |
17 | 18 | import jaxlib # type: ignore[import-not-found] |
18 | 19 |
|
19 | | - _has_jax = True |
20 | | -except ImportError: |
21 | | - _has_jax = False |
22 | 20 |
|
23 | 21 | _are_namespaces_registered = False |
24 | 22 |
|
@@ -146,7 +144,7 @@ def _get_equality_checkers(): |
146 | 144 | equality_checkers[pd.Series.__name__] = lambda a, b: a.equals(b) |
147 | 145 | equality_checkers[pd.DataFrame.__name__] = lambda a, b: a.equals(b) |
148 | 146 |
|
149 | | - if _has_jax: |
| 147 | + if IS_JAX_INSTALLED: |
150 | 148 | equality_checkers[jnp.ndarray.__name__] = lambda a, b: bool((a == b).all()) |
151 | 149 |
|
152 | 150 | return equality_checkers |
@@ -193,11 +191,15 @@ def _register_namespaces() -> None: |
193 | 191 | namespace=namespace, |
194 | 192 | ) |
195 | 193 |
|
196 | | - if _has_jax: |
| 194 | + if IS_JAX_INSTALLED: |
197 | 195 | optree.register_pytree_node( |
198 | 196 | jaxlib._jax.ArrayImpl, |
199 | | - _flatten_jax_array, |
200 | | - _unflatten_jax_array, |
| 197 | + lambda arr: ( |
| 198 | + arr.flatten().tolist(), # type: ignore[attr-defined] |
| 199 | + arr.shape, # type: ignore[attr-defined] |
| 200 | + _array_element_names(arr), # type: ignore[arg-type] |
| 201 | + ), |
| 202 | + lambda aux_data, leaves: jnp.array(leaves).reshape(aux_data), |
201 | 203 | namespace=namespace, |
202 | 204 | ) |
203 | 205 |
|
@@ -254,17 +256,6 @@ def _unflatten_ndarray(aux_data, leaves): |
254 | 256 | return np.array(leaves).reshape(aux_data) |
255 | 257 |
|
256 | 258 |
|
257 | | -if _has_jax: |
258 | | - |
259 | | - def _flatten_jax_array(arr): |
260 | | - """Flatten a jax array.""" |
261 | | - return arr.flatten().tolist(), arr.shape, _array_element_names(arr) |
262 | | - |
263 | | - def _unflatten_jax_array(aux_data, leaves): |
264 | | - """Unflatten a jax array.""" |
265 | | - return jnp.array(leaves).reshape(aux_data) |
266 | | - |
267 | | - |
268 | 259 | def _get_df_names(df: pd.DataFrame) -> list[str]: |
269 | 260 | """Get string names for dataframe leaf paths.""" |
270 | 261 | index_strings = list(df.index.map(_index_element_to_string)) |
|
0 commit comments