Skip to content

Commit 39ef1b9

Browse files
committed
fix backwards-in-time solve for srk's
1 parent 7ac1ea0 commit 39ef1b9

3 files changed

Lines changed: 39 additions & 23 deletions

File tree

diffrax/_solver/srk.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -360,8 +360,15 @@ def step(
360360
else:
361361
ignore_stage_g = jnp.array(self.tableau.ignore_stage_g)
362362

363-
# time increment
363+
# The internal time step, h, is always positive. The direction is
364+
# handled by _term.py. However, the drift control should be the signed
365+
# value of h, `signed_h`.
366+
# Futher, since _term.py blanket multiplies the output of the control
367+
# by `direction`, we must preserve the symmetry of the SpaceTimeLevyArea,
368+
# H, such that H(h) = H(-h).
364369
h = t1 - t0
370+
signed_h = drift.contr(t0, t1)
371+
direction = jnp.sign(signed_h)
365372

366373
# First the drift related stuff
367374
a = self._embed_a_lower(self.tableau.a, dtype)
@@ -402,7 +409,7 @@ def make_zeros_aux(leaf):
402409
levy_areas = []
403410
if self.tableau.coeffs_hh is not None: # space-time Lévy area
404411
assert isinstance(bm_inc, AbstractSpaceTimeLevyArea)
405-
levy_areas.append(bm_inc.H)
412+
levy_areas.append((direction * bm_inc.H**ω).ω)
406413
b_levy_list.append(jnp.asarray(self.tableau.coeffs_hh.b_sol, dtype=dtype))
407414

408415
if self.tableau.coeffs_kk is not None: # space-time-time Lévy area
@@ -443,7 +450,7 @@ def _comp_g(_t):
443450

444451
if self.tableau.coeffs_hh is not None: # space-time Lévy area
445452
assert isinstance(bm_inc, AbstractSpaceTimeLevyArea)
446-
levylist_kgs.append(diffusion.prod(g0, bm_inc.H))
453+
levylist_kgs.append(diffusion.prod(g0, (direction * bm_inc.H**ω).ω))
447454
a_levy.append(jnp.asarray(self.tableau.coeffs_hh.a, dtype=dtype))
448455

449456
if self.tableau.coeffs_kk is not None: # space-time-time Lévy area
@@ -545,7 +552,7 @@ def stage(
545552
z_j = (y0**ω + _drift_result**ω + _diffusion_result**ω).ω
546553

547554
def compute_and_insert_kf_j(_h_kfs_in):
548-
h_kf_j = drift.vf_prod(t0 + c_j * h, z_j, args, h)
555+
h_kf_j = drift.vf_prod(t0 + c_j * h, z_j, args, signed_h)
549556
return insert_jth_stage(_h_kfs_in, h_kf_j, j)
550557

551558
if ignore_stage_f is None:
@@ -611,7 +618,7 @@ def compute_and_insert_kg_j(_w_kgs_in, _levylist_kgs_in):
611618
# g depends on t. This term is of the form $(g1 - g0) * (0.5*W_n - H_n)$.
612619
if self.tableau.coeffs_hh is not None: # space-time Lévy area
613620
assert isinstance(bm_inc, AbstractSpaceTimeLevyArea)
614-
time_var_contr = (bm_inc.W**ω - 2.0 * bm_inc.H**ω).ω
621+
time_var_contr = (bm_inc.W**ω - 2.0 * direction * bm_inc.H**ω).ω
615622
time_var_term = diffusion.prod(g_delta, time_var_contr)
616623
else:
617624
time_var_term = diffusion.prod(g_delta, bm_inc.W)

test/helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,8 @@ def _batch_sde_solve(
160160
else:
161161
struct = w_shape
162162
bm = diffrax.VirtualBrownianTree(
163-
t0=t0,
164-
t1=t1,
163+
t0=jnp.minimum(t0, t1),
164+
t1=jnp.maximum(t0, t1),
165165
shape=struct,
166166
tol=bm_tol,
167167
key=key,

test/test_sde1.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -40,19 +40,19 @@ def _solvers_and_orders():
4040
# converges to its own limit (i.e. using itself as reference), and then in a
4141
# different test check whether that limit is the same as the Euler/Heun limit.
4242
@pytest.mark.parametrize("solver_ctr,noise,theoretical_order", _solvers_and_orders())
43-
@pytest.mark.parametrize(
44-
"dtype",
45-
(jnp.float64,),
46-
)
43+
@pytest.mark.parametrize("dtype", (jnp.float64,))
44+
@pytest.mark.parametrize("direction", ("forwards", "backwards"))
4745
def test_sde_strong_order_new(
48-
solver_ctr, noise: Literal["any", "com", "add"], theoretical_order, dtype
46+
solver_ctr, noise: Literal["any", "com", "add"], theoretical_order, dtype, direction
4947
):
5048
bmkey = jr.key(5678)
5149
sde_key = jr.key(11)
5250
num_samples = 100
5351
bmkeys = jr.split(bmkey, num=num_samples)
5452
t0 = 0.3
5553
t1 = 5.3
54+
if direction == "backwards":
55+
t0, t1 = t1, t0
5656

5757
if noise == "add":
5858
sde = get_time_sde(t0, t1, dtype, sde_key, noise_dim=7)
@@ -98,14 +98,20 @@ def get_dt_and_controller(level):
9898
# This is to avoid recomputing the correct solutions for every solver.
9999
solutions = {
100100
"Ito": {
101-
"any": None,
102-
"com": None,
103-
"add": None,
101+
("forwards", "any"): None,
102+
("forwards", "com"): None,
103+
("forwards", "add"): None,
104+
("backwards", "any"): None,
105+
("backwards", "com"): None,
106+
("backwards", "add"): None,
104107
},
105108
"Stratonovich": {
106-
"any": None,
107-
"com": None,
108-
"add": None,
109+
("forwards", "any"): None,
110+
("forwards", "com"): None,
111+
("forwards", "add"): None,
112+
("backwards", "any"): None,
113+
("backwards", "com"): None,
114+
("backwards", "add"): None,
109115
},
110116
}
111117

@@ -115,15 +121,18 @@ def get_dt_and_controller(level):
115121
# and Heun if the solver is Stratonovich.
116122
@pytest.mark.parametrize("solver_ctr,noise,theoretical_order", _solvers_and_orders())
117123
@pytest.mark.parametrize("dtype", (jnp.float64,))
124+
@pytest.mark.parametrize("direction", ("forwards", "backwards"))
118125
def test_sde_strong_limit(
119-
solver_ctr, noise: Literal["any", "com", "add"], theoretical_order, dtype
126+
solver_ctr, noise: Literal["any", "com", "add"], theoretical_order, dtype, direction
120127
):
121128
bmkey = jr.key(5678)
122129
sde_key = jr.key(11)
123130
num_samples = 100
124131
bmkeys = jr.split(bmkey, num=num_samples)
125132
t0 = 0.3
126133
t1 = 5.3
134+
if direction == "backwards":
135+
t0, t1 = t1, t0
127136

128137
if noise == "add":
129138
sde = get_time_sde(t0, t1, dtype, sde_key, noise_dim=3)
@@ -164,16 +173,16 @@ def test_sde_strong_limit(
164173
saveat = diffrax.SaveAt(ts=save_ts)
165174
levy_area = diffrax.SpaceTimeLevyArea # must be common for all solvers
166175

167-
if solutions[sol_type][noise] is None:
176+
if solutions[sol_type][(direction, noise)] is None:
168177
correct_sol, _ = simple_batch_sde_solve(
169178
bmkeys, sde, ref_solver, levy_area, None, contr_fine, 2**-10, saveat
170179
)
171-
solutions[sol_type][noise] = correct_sol
180+
solutions[sol_type][(direction, noise)] = correct_sol
172181
else:
173-
correct_sol = solutions[sol_type][noise]
182+
correct_sol = solutions[sol_type][(direction, noise)]
174183

175184
sol, _ = simple_batch_sde_solve(
176185
bmkeys, sde, solver_ctr(), levy_area, None, contr_coarse, 2**-10, saveat
177186
)
178187
error = path_l2_dist(correct_sol, sol)
179-
assert error < 0.05
188+
assert error < 0.1

0 commit comments

Comments
 (0)