Skip to content

Commit 4f42b27

Browse files
committed
chore: change jaxlib registration class
1 parent bd29352 commit 4f42b27

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

src/optimagic/parameters/tree_registry.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
try:
1515
import jax.numpy as jnp # type: ignore[import-not-found]
16+
import jaxlib # type: ignore[import-not-found]
1617

1718
_has_jax = True
1819
except ImportError:
@@ -166,9 +167,8 @@ def _register_namespaces() -> None:
166167
)
167168

168169
if _has_jax:
169-
_jax_array_type = type(jnp.empty(0))
170170
optree.register_pytree_node(
171-
_jax_array_type,
171+
jaxlib._jax.ArrayImpl,
172172
_flatten_jax_array,
173173
_unflatten_jax_array,
174174
namespace=namespace,

0 commit comments

Comments
 (0)