11from typing import Optional , Union
22
3- from openequivariance .core .TensorProductBase import TensorProductBase
4- from openequivariance .core .e3nn_lite import TPProblem
3+ import numpy as np
4+ import numpy .linalg as la
5+
56from openequivariance ._torch .CUETensorProduct import CUETensorProduct
7+ from openequivariance .benchmark .logging_utils import bcolors , getLogger
68from openequivariance .benchmark .random_buffer_utils import (
7- get_random_buffers_forward ,
89 get_random_buffers_backward ,
910 get_random_buffers_double_backward ,
11+ get_random_buffers_forward ,
1012)
11-
12- from openequivariance .benchmark .logging_utils import getLogger , bcolors
13- import numpy as np
14- import numpy .linalg as la
13+ from openequivariance .core .e3nn_lite import TPProblem
14+ from openequivariance .core .TensorProductBase import TensorProductBase
15+ from openequivariance .core .utils import IrrepLayoutUtils
1516
1617logger = getLogger ()
1718
@@ -81,7 +82,7 @@ def correctness_forward(
8182
8283 in1 , in2 , weights , out = get_random_buffers_forward (problem , batch_size , prng_seed )
8384
84- # run reference
85+ # run reference (always in mul_ir)
8586 ref_tp = reference_implementation (problem )
8687
8788 ref_out = out .copy ()
@@ -93,13 +94,31 @@ def correctness_forward(
9394 if problem .shared_weights and test_implementation == CUETensorProduct :
9495 weights_copy = weights [np .newaxis , :]
9596
96- # run test
97+ # run test (may require ir_mul conversion)
9798 test_tp = instantiate_implementation (test_implementation , problem )
99+ test_layout = getattr (test_tp .config , "layout" , "mul_ir" )
100+
101+ test_in1 = in1 .copy ()
102+ test_in2 = in2 .copy ()
98103 test_out = out .copy ()
104+
105+ if test_layout == "ir_mul" :
106+ test_in1 = IrrepLayoutUtils .transpose_irrep_layout (
107+ test_in1 , problem .irreps_in1 , "mul_ir" , "ir_mul"
108+ )
109+ test_in2 = IrrepLayoutUtils .transpose_irrep_layout (
110+ test_in2 , problem .irreps_in2 , "mul_ir" , "ir_mul"
111+ )
112+
99113 test_tp .forward_cpu (
100- L1_in = in1 . copy () , L2_in = in2 . copy () , L3_out = test_out , weights = weights_copy
114+ L1_in = test_in1 , L2_in = test_in2 , L3_out = test_out , weights = weights_copy
101115 )
102116
117+ if test_layout == "ir_mul" :
118+ test_out = IrrepLayoutUtils .transpose_irrep_layout (
119+ test_out , problem .irreps_out , "ir_mul" , "mul_ir"
120+ )
121+
103122 for name , to_check , ground_truth in [("output" , ref_out , test_out )]:
104123 result [name ] = check_similiarity (
105124 name , to_check , ground_truth , correctness_threshold
@@ -144,7 +163,7 @@ def correctness_backward(
144163 weights_grad = ref_weights_grad ,
145164 )
146165
147- # run test version
166+ # run test version (may require ir_mul conversion)
148167 test_weights_grad = weights_grad .copy ()
149168 test_in1_grad = in1_grad .copy ()
150169 test_in2_grad = in2_grad .copy ()
@@ -156,16 +175,41 @@ def correctness_backward(
156175 test_weights_grad = test_weights_grad [np .newaxis , :]
157176
158177 test_tp = instantiate_implementation (test_implementation , problem )
178+ test_layout = getattr (test_tp .config , "layout" , "mul_ir" )
179+
180+ test_in1 = in1 .copy ()
181+ test_in2 = in2 .copy ()
182+ test_L3_grad = out_grad .copy ()
183+
184+ if test_layout == "ir_mul" :
185+ test_in1 = IrrepLayoutUtils .transpose_irrep_layout (
186+ test_in1 , problem .irreps_in1 , "mul_ir" , "ir_mul"
187+ )
188+ test_in2 = IrrepLayoutUtils .transpose_irrep_layout (
189+ test_in2 , problem .irreps_in2 , "mul_ir" , "ir_mul"
190+ )
191+ test_L3_grad = IrrepLayoutUtils .transpose_irrep_layout (
192+ test_L3_grad , problem .irreps_out , "mul_ir" , "ir_mul"
193+ )
194+
159195 test_tp .backward_cpu (
160- L1_in = in1 . copy () ,
196+ L1_in = test_in1 ,
161197 L1_grad = test_in1_grad ,
162- L2_in = in2 . copy () ,
198+ L2_in = test_in2 ,
163199 L2_grad = test_in2_grad ,
164- L3_grad = out_grad . copy () ,
200+ L3_grad = test_L3_grad ,
165201 weights = weights_copy ,
166202 weights_grad = test_weights_grad ,
167203 )
168204
205+ if test_layout == "ir_mul" :
206+ test_in1_grad = IrrepLayoutUtils .transpose_irrep_layout (
207+ test_in1_grad , problem .irreps_in1 , "ir_mul" , "mul_ir"
208+ )
209+ test_in2_grad = IrrepLayoutUtils .transpose_irrep_layout (
210+ test_in2_grad , problem .irreps_in2 , "ir_mul" , "mul_ir"
211+ )
212+
169213 weight_threshold = (
170214 correctness_threshold * batch_size
171215 if problem .shared_weights
@@ -210,7 +254,9 @@ def correctness_double_backward(
210254 result = {"thresh" : correctness_threshold , "batch_size" : batch_size }
211255
212256 tensors = []
213- for _ , impl in enumerate ([test_implementation , reference_implementation ]):
257+ for is_test_impl , impl in enumerate (
258+ [test_implementation , reference_implementation ]
259+ ):
214260 tp = instantiate_implementation (impl , problem )
215261 weights_reordered = tp .reorder_weights_from_e3nn (
216262 weights , has_batch_dim = not problem .shared_weights
@@ -222,15 +268,53 @@ def correctness_double_backward(
222268 if impl == CUETensorProduct and problem .shared_weights :
223269 weights_reordered = weights_reordered [np .newaxis , :]
224270
271+ tp_layout = getattr (tp .config , "layout" , "mul_ir" )
272+ apply_test_layout = is_test_impl == 0 and tp_layout == "ir_mul"
273+
274+ db_in1 = in1
275+ db_in2 = in2
276+ db_out_grad = out_grad
277+ db_in1_dgrad = in1_dgrad
278+ db_in2_dgrad = in2_dgrad
279+
280+ if apply_test_layout :
281+ db_in1 = IrrepLayoutUtils .transpose_irrep_layout (
282+ in1 , problem .irreps_in1 , "mul_ir" , "ir_mul"
283+ )
284+ db_in2 = IrrepLayoutUtils .transpose_irrep_layout (
285+ in2 , problem .irreps_in2 , "mul_ir" , "ir_mul"
286+ )
287+ db_out_grad = IrrepLayoutUtils .transpose_irrep_layout (
288+ out_grad , problem .irreps_out , "mul_ir" , "ir_mul"
289+ )
290+ db_in1_dgrad = IrrepLayoutUtils .transpose_irrep_layout (
291+ in1_dgrad , problem .irreps_in1 , "mul_ir" , "ir_mul"
292+ )
293+ db_in2_dgrad = IrrepLayoutUtils .transpose_irrep_layout (
294+ in2_dgrad , problem .irreps_in2 , "mul_ir" , "ir_mul"
295+ )
296+
225297 in1_grad , in2_grad , weights_grad , out_dgrad = tp .double_backward_cpu (
226- in1 ,
227- in2 ,
228- out_grad ,
298+ db_in1 ,
299+ db_in2 ,
300+ db_out_grad ,
229301 weights_reordered ,
230302 weights_dgrad_reordered ,
231- in1_dgrad ,
232- in2_dgrad ,
303+ db_in1_dgrad ,
304+ db_in2_dgrad ,
233305 )
306+
307+ if apply_test_layout :
308+ out_dgrad = IrrepLayoutUtils .transpose_irrep_layout (
309+ out_dgrad , problem .irreps_out , "ir_mul" , "mul_ir"
310+ )
311+ in1_grad = IrrepLayoutUtils .transpose_irrep_layout (
312+ in1_grad , problem .irreps_in1 , "ir_mul" , "mul_ir"
313+ )
314+ in2_grad = IrrepLayoutUtils .transpose_irrep_layout (
315+ in2_grad , problem .irreps_in2 , "ir_mul" , "mul_ir"
316+ )
317+
234318 tensors .append (
235319 (
236320 out_dgrad ,
0 commit comments