@@ -40,19 +40,19 @@ def _solvers_and_orders():
4040# converges to its own limit (i.e. using itself as reference), and then in a
4141# different test check whether that limit is the same as the Euler/Heun limit.
4242@pytest .mark .parametrize ("solver_ctr,noise,theoretical_order" , _solvers_and_orders ())
43- @pytest .mark .parametrize (
44- "dtype" ,
45- (jnp .float64 ,),
46- )
43+ @pytest .mark .parametrize ("dtype" , (jnp .float64 ,))
44+ @pytest .mark .parametrize ("direction" , ("forwards" , "backwards" ))
4745def test_sde_strong_order_new (
48- solver_ctr , noise : Literal ["any" , "com" , "add" ], theoretical_order , dtype
46+ solver_ctr , noise : Literal ["any" , "com" , "add" ], theoretical_order , dtype , direction
4947):
5048 bmkey = jr .key (5678 )
5149 sde_key = jr .key (11 )
5250 num_samples = 100
5351 bmkeys = jr .split (bmkey , num = num_samples )
5452 t0 = 0.3
5553 t1 = 5.3
54+ if direction == "backwards" :
55+ t0 , t1 = t1 , t0
5656
5757 if noise == "add" :
5858 sde = get_time_sde (t0 , t1 , dtype , sde_key , noise_dim = 7 )
@@ -98,14 +98,20 @@ def get_dt_and_controller(level):
9898# This is to avoid recomputing the correct solutions for every solver.
9999solutions = {
100100 "Ito" : {
101- "any" : None ,
102- "com" : None ,
103- "add" : None ,
101+ ("forwards" , "any" ): None ,
102+ ("forwards" , "com" ): None ,
103+ ("forwards" , "add" ): None ,
104+ ("backwards" , "any" ): None ,
105+ ("backwards" , "com" ): None ,
106+ ("backwards" , "add" ): None ,
104107 },
105108 "Stratonovich" : {
106- "any" : None ,
107- "com" : None ,
108- "add" : None ,
109+ ("forwards" , "any" ): None ,
110+ ("forwards" , "com" ): None ,
111+ ("forwards" , "add" ): None ,
112+ ("backwards" , "any" ): None ,
113+ ("backwards" , "com" ): None ,
114+ ("backwards" , "add" ): None ,
109115 },
110116}
111117
@@ -115,15 +121,18 @@ def get_dt_and_controller(level):
115121# and Heun if the solver is Stratonovich.
116122@pytest .mark .parametrize ("solver_ctr,noise,theoretical_order" , _solvers_and_orders ())
117123@pytest .mark .parametrize ("dtype" , (jnp .float64 ,))
124+ @pytest .mark .parametrize ("direction" , ("forwards" , "backwards" ))
118125def test_sde_strong_limit (
119- solver_ctr , noise : Literal ["any" , "com" , "add" ], theoretical_order , dtype
126+ solver_ctr , noise : Literal ["any" , "com" , "add" ], theoretical_order , dtype , direction
120127):
121128 bmkey = jr .key (5678 )
122129 sde_key = jr .key (11 )
123130 num_samples = 100
124131 bmkeys = jr .split (bmkey , num = num_samples )
125132 t0 = 0.3
126133 t1 = 5.3
134+ if direction == "backwards" :
135+ t0 , t1 = t1 , t0
127136
128137 if noise == "add" :
129138 sde = get_time_sde (t0 , t1 , dtype , sde_key , noise_dim = 3 )
@@ -164,16 +173,16 @@ def test_sde_strong_limit(
164173 saveat = diffrax .SaveAt (ts = save_ts )
165174 levy_area = diffrax .SpaceTimeLevyArea # must be common for all solvers
166175
167- if solutions [sol_type ][noise ] is None :
176+ if solutions [sol_type ][( direction , noise ) ] is None :
168177 correct_sol , _ = simple_batch_sde_solve (
169178 bmkeys , sde , ref_solver , levy_area , None , contr_fine , 2 ** - 10 , saveat
170179 )
171- solutions [sol_type ][noise ] = correct_sol
180+ solutions [sol_type ][( direction , noise ) ] = correct_sol
172181 else :
173- correct_sol = solutions [sol_type ][noise ]
182+ correct_sol = solutions [sol_type ][( direction , noise ) ]
174183
175184 sol , _ = simple_batch_sde_solve (
176185 bmkeys , sde , solver_ctr (), levy_area , None , contr_coarse , 2 ** - 10 , saveat
177186 )
178187 error = path_l2_dist (correct_sol , sol )
179- assert error < 0.05
188+ assert error < 0.1
0 commit comments