Skip to content

Commit 7f0001c

Browse files
jpbrodrick89patrick-kidger
authored andcommitted
make diffrax compatible with new lineax pre-release
1 parent 0f77ba0 commit 7f0001c

2 files changed

Lines changed: 6 additions & 2 deletions

File tree

diffrax/_term.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -736,7 +736,8 @@ class WrapTerm(AbstractTerm[_VF, _Control]):
736736
direction: IntScalarLike
737737

738738
def vf(self, t: RealScalarLike, y: Y, args: Args) -> _VF:
739-
t = t * self.direction
739+
with jax.numpy_dtype_promotion("standard"):
740+
t = t * self.direction
740741
return self.term.vf(t, y, args)
741742

742743
def contr(self, t0: RealScalarLike, t1: RealScalarLike, **kwargs) -> _Control:
@@ -749,7 +750,8 @@ def prod(self, vf: _VF, control: _Control) -> Y:
749750
return self.term.prod(vf, control)
750751

751752
def vf_prod(self, t: RealScalarLike, y: Y, args: Args, control: _Control) -> Y:
752-
t = t * self.direction
753+
with jax.numpy_dtype_promotion("standard"):
754+
t = t * self.direction
753755
return self.term.vf_prod(t, y, args, control)
754756

755757
def is_vf_expensive(

test/test_integrate.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,8 @@ def out_structure(self):
631631
return (jax.ShapeDtypeStruct((2, 3), jnp.float64),)
632632

633633
@lx.is_symmetric.register(TestLinearOperator)
634+
@lx.is_positive_semidefinite.register(TestLinearOperator)
635+
@lx.is_negative_semidefinite.register(TestLinearOperator)
634636
def _(operator):
635637
del operator
636638
return False

0 commit comments

Comments
 (0)