We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent bd29352 commit 4f42b27Copy full SHA for 4f42b27
1 file changed
src/optimagic/parameters/tree_registry.py
@@ -13,6 +13,7 @@
13
14
try:
15
import jax.numpy as jnp # type: ignore[import-not-found]
16
+ import jaxlib # type: ignore[import-not-found]
17
18
_has_jax = True
19
except ImportError:
@@ -166,9 +167,8 @@ def _register_namespaces() -> None:
166
167
)
168
169
if _has_jax:
- _jax_array_type = type(jnp.empty(0))
170
optree.register_pytree_node(
171
- _jax_array_type,
+ jaxlib._jax.ArrayImpl,
172
_flatten_jax_array,
173
_unflatten_jax_array,
174
namespace=namespace,
0 commit comments