77from itertools import chain , product
88
99class ConvCorrectness :
10+ def thresh (self , direction ):
11+ return {
12+ "fwd" : 1e-5 ,
13+ "bwd" : 3e-4 ,
14+ "double_bwd" : 3e-4
15+ }[direction ]
16+
17+
1018 def check_result (self , result , fieldname ):
1119 with check :
1220 error = result [fieldname ]["diff_Linf_norm" ]
1321 thresh = result ["thresh" ]
14- assert result [fieldname ]["pass" ], f"{ fieldname } observed error={ error :.2f } >= { thresh } "
15-
22+ assert result [fieldname ]["pass" ], f"{ fieldname } observed error={ error :.5f } >= { thresh } "
23+
1624 @pytest .fixture (params = [np .float32 , np .float64 ], ids = ['F32' , 'F64' ], scope = 'class' )
1725 def dtype (self , request ):
1826 return request .param
@@ -48,7 +56,7 @@ def test_tp_fwd(self, conv_object, graph):
4856 return
4957
5058 result = conv_object .test_correctness_forward (graph ,
51- thresh = 3e-05 ,
59+ thresh = self . thresh ( "fwd" ) ,
5260 prng_seed = 12345 ,
5361 reference_implementation = None )
5462
@@ -60,7 +68,7 @@ def test_tp_bwd(self, conv_object, graph):
6068 return
6169
6270 result = conv_object .test_correctness_backward (graph ,
63- thresh = 3e-04 ,
71+ thresh = self . thresh ( "bwd" ) ,
6472 prng_seed = 12345 ,
6573 reference_implementation = None )
6674
@@ -74,7 +82,7 @@ def test_tp_double_bwd(self, conv_object, graph):
7482 return
7583
7684 result = conv_object .test_correctness_double_backward (graph ,
77- thresh = 3e-04 ,
85+ thresh = self . thresh ( "double_bwd" ) ,
7886 prng_seed = 12345 ,
7987 reference_implementation = None )
8088
@@ -140,4 +148,27 @@ def problem(self, request, dtype):
140148 return oeq .TPProblem (f"{ m [0 ]} x{ i [0 ]} e" , f"{ m [1 ]} x{ i [1 ]} e" , f"{ m [2 ]} x{ i [2 ]} e" ,
141149 instructions , shared_weights = False ,
142150 internal_weights = False ,
143- irrep_dtype = dtype , weight_dtype = dtype )
151+ irrep_dtype = dtype , weight_dtype = dtype )
152+
153+
154+ class TestAtomicSharedWeights (ConvCorrectness ):
155+ from openequivariance .benchmark .benchmark_configs import mace_problems , diffdock_configs
156+ problems = [mace_problems [0 ], diffdock_configs [0 ]]
157+
158+ def thresh (self , direction ):
159+ return {
160+ "fwd" : 1e-5 ,
161+ "bwd" : 5e-2 , # Expect higher errors for shared weights
162+ "double_bwd" : 5e-2
163+ }[direction ]
164+
165+ @pytest .fixture (params = problems , ids = lambda x : x .label , scope = "class" )
166+ def problem (self , request , dtype ):
167+ problem = request .param
168+ problem .irrep_dtype , problem .weight_dtype = dtype , dtype
169+ problem .shared_weights = True
170+ return problem
171+
172+ @pytest .fixture (scope = 'class' )
173+ def conv_object (self , request , problem ):
174+ return oeq .TensorProductConv (problem , deterministic = False )
0 commit comments