@@ -79,47 +79,39 @@ def correctness_forward(
7979 reference_implementation = E3NNTensorProduct
8080
8181 result = {"thresh" : correctness_threshold , "batch_size" : batch_size }
82-
8382 in1 , in2 , weights , out = get_random_buffers_forward (problem , batch_size , prng_seed )
83+ outputs = []
8484
85- # run reference (always in mul_ir)
86- ref_tp = reference_implementation (problem )
87-
88- ref_out = out .copy ()
89- ref_tp .forward_cpu (
90- L1_in = in1 .copy (), L2_in = in2 .copy (), L3_out = ref_out , weights = weights .copy ()
91- )
92-
93- weights_copy = weights .copy ()
94- if problem .shared_weights and test_implementation == CUETensorProduct :
95- weights_copy = weights [np .newaxis , :]
96-
97- # run test (may require ir_mul conversion)
98- test_tp = instantiate_implementation (test_implementation , problem )
99- test_layout = getattr (test_tp .config , "layout" , "mul_ir" )
100-
101- test_in1 = in1 .copy ()
102- test_in2 = in2 .copy ()
103- test_out = out .copy ()
104-
105- if test_layout == "ir_mul" :
106- test_in1 = IrrepLayoutUtils .transpose_irrep_layout (
107- test_in1 , problem .irreps_in1 , "mul_ir" , "ir_mul"
108- )
109- test_in2 = IrrepLayoutUtils .transpose_irrep_layout (
110- test_in2 , problem .irreps_in2 , "mul_ir" , "ir_mul"
111- )
112-
113- test_tp .forward_cpu (
114- L1_in = test_in1 , L2_in = test_in2 , L3_out = test_out , weights = weights_copy
115- )
85+ for i , impl in enumerate ([test_implementation , reference_implementation ]):
86+ is_test_impl = (i == 0 )
87+ tp = instantiate_implementation (impl , problem )
88+ uses_cue = impl == CUETensorProduct or isinstance (tp , CUETensorProduct )
89+ run_in1 , run_in2 , run_weights , run_out = [ buf .copy () for buf in (in1 , in2 , weights , out ) ]
90+
91+ if problem .shared_weights and uses_cue :
92+ run_weights = run_weights [np .newaxis , :]
93+
94+ # Transpose inputs, if necessary, for the test implementation
95+ if is_test_impl :
96+ run_in1 , run_in2 = [
97+ IrrepLayoutUtils .transpose_irrep_layout (
98+ arr , irreps , "mul_ir" , tp .config .layout
99+ ) for arr , irreps in zip (
100+ (run_in1 , run_in2 ),
101+ (problem .irreps_in1 , problem .irreps_in2 )
102+ )
103+ ]
104+
105+ tp .forward_cpu (L1_in = run_in1 , L2_in = run_in2 , L3_out = run_out , weights = run_weights )
106+
107+ if is_test_impl :
108+ run_out = IrrepLayoutUtils .transpose_irrep_layout (
109+ run_out , problem .irreps_out , tp .config .layout , "mul_ir"
110+ )
116111
117- if test_layout == "ir_mul" :
118- test_out = IrrepLayoutUtils .transpose_irrep_layout (
119- test_out , problem .irreps_out , "ir_mul" , "mul_ir"
120- )
112+ outputs .append (run_out )
121113
122- for name , to_check , ground_truth in [("output" , ref_out , test_out )]:
114+ for name , to_check , ground_truth in [("output" , outputs [ 0 ], outputs [ 1 ] )]:
123115 result [name ] = check_similiarity (
124116 name , to_check , ground_truth , correctness_threshold
125117 )
@@ -142,87 +134,72 @@ def correctness_backward(
142134
143135 result = {"thresh" : correctness_threshold , "batch_size" : batch_size }
144136
145- # run reference
146137 in1 , in2 , out_grad , weights , weights_grad , in1_grad , in2_grad = (
147138 get_random_buffers_backward (problem , batch_size , prng_seed )
148139 )
149140
150- ref_tp = reference_implementation (problem )
151-
152- ref_weights_grad = weights_grad .copy ()
153- ref_in1_grad = in1_grad .copy ()
154- ref_in2_grad = in2_grad .copy ()
155-
156- ref_tp .backward_cpu (
157- L1_in = in1 .copy (),
158- L1_grad = ref_in1_grad ,
159- L2_in = in2 .copy (),
160- L2_grad = ref_in2_grad ,
161- L3_grad = out_grad .copy (),
162- weights = weights .copy (),
163- weights_grad = ref_weights_grad ,
164- )
165-
166- # run test version (may require ir_mul conversion)
167- test_weights_grad = weights_grad .copy ()
168- test_in1_grad = in1_grad .copy ()
169- test_in2_grad = in2_grad .copy ()
170-
171- weights_copy = weights .copy ()
172-
173- if problem .shared_weights and test_implementation == CUETensorProduct :
174- weights_copy = weights [np .newaxis , :]
175- test_weights_grad = test_weights_grad [np .newaxis , :]
176-
177- test_tp = instantiate_implementation (test_implementation , problem )
178- test_layout = getattr (test_tp .config , "layout" , "mul_ir" )
179-
180- test_in1 = in1 .copy ()
181- test_in2 = in2 .copy ()
182- test_L3_grad = out_grad .copy ()
141+ grads = []
142+ for i , impl in enumerate ([test_implementation , reference_implementation ]):
143+ is_test_impl = i == 0
144+ tp = instantiate_implementation (impl , problem )
183145
184- if test_layout == "ir_mul" :
185- test_in1 = IrrepLayoutUtils .transpose_irrep_layout (
186- test_in1 , problem .irreps_in1 , "mul_ir" , "ir_mul"
187- )
188- test_in2 = IrrepLayoutUtils .transpose_irrep_layout (
189- test_in2 , problem .irreps_in2 , "mul_ir" , "ir_mul"
190- )
191- test_L3_grad = IrrepLayoutUtils .transpose_irrep_layout (
192- test_L3_grad , problem .irreps_out , "mul_ir" , "ir_mul"
146+ run_in1 , run_in2 , run_L3_grad , run_weights , run_weights_grad , run_in1_grad , run_in2_grad = [
147+ buf .copy ()
148+ for buf in (in1 , in2 , out_grad , weights , weights_grad , in1_grad , in2_grad )
149+ ]
150+
151+ uses_cue = impl == CUETensorProduct or isinstance (tp , CUETensorProduct )
152+ if problem .shared_weights and uses_cue :
153+ run_weights = run_weights [np .newaxis , :]
154+ run_weights_grad = run_weights_grad [np .newaxis , :]
155+
156+ if is_test_impl :
157+ run_in1 , run_in2 , run_L3_grad = [
158+ IrrepLayoutUtils .transpose_irrep_layout (
159+ arr , irreps , "mul_ir" , tp .config .layout
160+ )
161+ for arr , irreps in zip (
162+ (run_in1 , run_in2 , run_L3_grad ),
163+ (problem .irreps_in1 , problem .irreps_in2 , problem .irreps_out ),
164+ )
165+ ]
166+
167+ tp .backward_cpu (
168+ L1_in = run_in1 ,
169+ L1_grad = run_in1_grad ,
170+ L2_in = run_in2 ,
171+ L2_grad = run_in2_grad ,
172+ L3_grad = run_L3_grad ,
173+ weights = run_weights ,
174+ weights_grad = run_weights_grad ,
193175 )
194176
195- test_tp .backward_cpu (
196- L1_in = test_in1 ,
197- L1_grad = test_in1_grad ,
198- L2_in = test_in2 ,
199- L2_grad = test_in2_grad ,
200- L3_grad = test_L3_grad ,
201- weights = weights_copy ,
202- weights_grad = test_weights_grad ,
203- )
177+ if is_test_impl :
178+ run_in1_grad , run_in2_grad = [
179+ IrrepLayoutUtils .transpose_irrep_layout (
180+ arr , irreps , tp .config .layout , "mul_ir"
181+ )
182+ for arr , irreps in zip (
183+ (run_in1_grad , run_in2_grad ),
184+ (problem .irreps_in1 , problem .irreps_in2 ),
185+ )
186+ ]
204187
205- if test_layout == "ir_mul" :
206- test_in1_grad = IrrepLayoutUtils .transpose_irrep_layout (
207- test_in1_grad , problem .irreps_in1 , "ir_mul" , "mul_ir"
208- )
209- test_in2_grad = IrrepLayoutUtils .transpose_irrep_layout (
210- test_in2_grad , problem .irreps_in2 , "ir_mul" , "mul_ir"
211- )
188+ if problem .shared_weights :
189+ run_weights_grad = run_weights_grad .squeeze ()
190+
191+ grads .append ((run_weights_grad , run_in1_grad , run_in2_grad ))
212192
213193 weight_threshold = (
214194 correctness_threshold * batch_size
215195 if problem .shared_weights
216196 else correctness_threshold
217197 )
218198
219- if problem .shared_weights :
220- test_weights_grad = test_weights_grad .squeeze ()
221-
222199 for name , to_check , ground_truth , threshold in [
223- ("weight_grad" , test_weights_grad , ref_weights_grad , weight_threshold ),
224- ("in1_grad" , test_in1_grad , ref_in1_grad , correctness_threshold ),
225- ("in2_grad" , test_in2_grad , ref_in2_grad , correctness_threshold ),
200+ ("weight_grad" , grads [ 0 ][ 0 ], grads [ 1 ][ 0 ] , weight_threshold ),
201+ ("in1_grad" , grads [ 0 ][ 1 ], grads [ 1 ][ 1 ] , correctness_threshold ),
202+ ("in2_grad" , grads [ 0 ][ 2 ], grads [ 1 ][ 2 ] , correctness_threshold ),
226203 ]:
227204 result [name ] = check_similiarity (name , to_check , ground_truth , threshold )
228205
@@ -254,9 +231,8 @@ def correctness_double_backward(
254231 result = {"thresh" : correctness_threshold , "batch_size" : batch_size }
255232
256233 tensors = []
257- for is_test_impl , impl in enumerate (
258- [test_implementation , reference_implementation ]
259- ):
234+ for i , impl in enumerate ([test_implementation , reference_implementation ]):
235+ is_test_impl = i == 0
260236 tp = instantiate_implementation (impl , problem )
261237 weights_reordered = tp .reorder_weights_from_e3nn (
262238 weights , has_batch_dim = not problem .shared_weights
@@ -268,31 +244,26 @@ def correctness_double_backward(
268244 if impl == CUETensorProduct and problem .shared_weights :
269245 weights_reordered = weights_reordered [np .newaxis , :]
270246
271- tp_layout = getattr (tp .config , "layout" , "mul_ir" )
272- apply_test_layout = is_test_impl == 0 and tp_layout == "ir_mul"
273-
274- db_in1 = in1
275- db_in2 = in2
276- db_out_grad = out_grad
277- db_in1_dgrad = in1_dgrad
278- db_in2_dgrad = in2_dgrad
279-
280- if apply_test_layout :
281- db_in1 = IrrepLayoutUtils .transpose_irrep_layout (
282- in1 , problem .irreps_in1 , "mul_ir" , "ir_mul"
283- )
284- db_in2 = IrrepLayoutUtils .transpose_irrep_layout (
285- in2 , problem .irreps_in2 , "mul_ir" , "ir_mul"
286- )
287- db_out_grad = IrrepLayoutUtils .transpose_irrep_layout (
288- out_grad , problem .irreps_out , "mul_ir" , "ir_mul"
289- )
290- db_in1_dgrad = IrrepLayoutUtils .transpose_irrep_layout (
291- in1_dgrad , problem .irreps_in1 , "mul_ir" , "ir_mul"
292- )
293- db_in2_dgrad = IrrepLayoutUtils .transpose_irrep_layout (
294- in2_dgrad , problem .irreps_in2 , "mul_ir" , "ir_mul"
295- )
247+ db_in1 , db_in2 , db_out_grad , db_in1_dgrad , db_in2_dgrad = [
248+ buf .copy () for buf in (in1 , in2 , out_grad , in1_dgrad , in2_dgrad )
249+ ]
250+
251+ if is_test_impl :
252+ db_in1 , db_in2 , db_out_grad , db_in1_dgrad , db_in2_dgrad = [
253+ IrrepLayoutUtils .transpose_irrep_layout (
254+ arr , irreps , "mul_ir" , tp .config .layout
255+ )
256+ for arr , irreps in zip (
257+ (db_in1 , db_in2 , db_out_grad , db_in1_dgrad , db_in2_dgrad ),
258+ (
259+ problem .irreps_in1 ,
260+ problem .irreps_in2 ,
261+ problem .irreps_out ,
262+ problem .irreps_in1 ,
263+ problem .irreps_in2 ,
264+ ),
265+ )
266+ ]
296267
297268 in1_grad , in2_grad , weights_grad , out_dgrad = tp .double_backward_cpu (
298269 db_in1 ,
@@ -304,16 +275,16 @@ def correctness_double_backward(
304275 db_in2_dgrad ,
305276 )
306277
307- if apply_test_layout :
308- out_dgrad = IrrepLayoutUtils . transpose_irrep_layout (
309- out_dgrad , problem . irreps_out , "ir_mul" , "mul_ir"
310- )
311- in1_grad = IrrepLayoutUtils . transpose_irrep_layout (
312- in1_grad , problem . irreps_in1 , "ir_mul" , "mul_ir"
313- )
314- in2_grad = IrrepLayoutUtils . transpose_irrep_layout (
315- in2_grad , problem . irreps_in2 , "ir_mul" , "mul_ir"
316- )
278+ if is_test_impl :
279+ out_dgrad , in1_grad , in2_grad = [
280+ IrrepLayoutUtils . transpose_irrep_layout (
281+ arr , irreps , tp . config . layout , "mul_ir"
282+ )
283+ for arr , irreps in zip (
284+ ( out_dgrad , in1_grad , in2_grad ),
285+ ( problem . irreps_out , problem . irreps_in1 , problem . irreps_in2 ),
286+ )
287+ ]
317288
318289 tensors .append (
319290 (
0 commit comments