Skip to content

Commit bd29352

Browse files
committed
fix: use jax array type for registering a jax array
1 parent 651fa5e commit bd29352

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

src/optimagic/parameters/tree_registry.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,9 @@ def _register_namespaces() -> None:
166166
)
167167

168168
if _has_jax:
169+
_jax_array_type = type(jnp.empty(0))
169170
optree.register_pytree_node(
170-
jnp.ndarray,
171+
_jax_array_type,
171172
_flatten_jax_array,
172173
_unflatten_jax_array,
173174
namespace=namespace,

0 commit comments

Comments
 (0)