@@ -144,13 +144,6 @@ def test_correctness_forward(
144144 check_reproducible = True ,
145145 high_precision_ref = False ,
146146 ):
147- def maybe_transpose_input_for_test_impl (x , irreps ):
148- if self .config .layout == "ir_mul" :
149- return IrrepLayoutUtils .transpose_irrep_layout (
150- x , irreps , "mul_ir" , "ir_mul"
151- )
152- return x
153-
154147 if reference_implementation is None :
155148 from openequivariance ._torch .E3NNConv import E3NNConv
156149
@@ -192,23 +185,29 @@ def maybe_transpose_input_for_test_impl(x, irreps):
192185
193186 ref_out [:] = ref_tp .forward (** args ).cpu ().numpy ()
194187
195- test_out = out .copy ()
188+ run_in1 , run_in2 , run_weights , test_out = [
189+ buf .copy () for buf in (in1 , in2 , weights , out )
190+ ]
191+ run_in1 , run_in2 = [
192+ IrrepLayoutUtils .transpose_irrep_layout (
193+ arr , irreps , "mul_ir" , self .config .layout
194+ )
195+ for arr , irreps in zip (
196+ (run_in1 , run_in2 ),
197+ (self .config .irreps_in1 , self .config .irreps_in2 ),
198+ )
199+ ]
196200 self .forward_cpu (
197- L1_in = maybe_transpose_input_for_test_impl (
198- in1 .copy (), self .config .irreps_in1
199- ),
200- L2_in = maybe_transpose_input_for_test_impl (
201- in2 .copy (), self .config .irreps_in2
202- ),
203- weights = weights .copy (),
201+ L1_in = run_in1 ,
202+ L2_in = run_in2 ,
203+ weights = run_weights ,
204204 L3_out = test_out ,
205205 graph = graph ,
206206 )
207207
208- if self .config .layout == "ir_mul" :
209- test_out = IrrepLayoutUtils .transpose_irrep_layout (
210- test_out , self .config .irreps_out , "ir_mul" , "mul_ir"
211- )
208+ test_out = IrrepLayoutUtils .transpose_irrep_layout (
209+ test_out , self .config .irreps_out , self .config .layout , "mul_ir"
210+ )
212211
213212 for name , to_check , ground_truth in [("output" , ref_out , test_out )]:
214213 result [name ] = check_similiarity (name , to_check , ground_truth , thresh )
@@ -221,22 +220,29 @@ def maybe_transpose_input_for_test_impl(x, irreps):
221220
222221 for i in range (num_trials ):
223222 repeated_run = out .copy ()
223+ rep_in1 , rep_in2 , rep_weights = [
224+ buf .copy () for buf in (in1 , in2 , weights )
225+ ]
226+ rep_in1 , rep_in2 = [
227+ IrrepLayoutUtils .transpose_irrep_layout (
228+ arr , irreps , "mul_ir" , self .config .layout
229+ )
230+ for arr , irreps in zip (
231+ (rep_in1 , rep_in2 ),
232+ (self .config .irreps_in1 , self .config .irreps_in2 ),
233+ )
234+ ]
224235 self .forward_cpu (
225- L1_in = maybe_transpose_input_for_test_impl (
226- in1 .copy (), self .config .irreps_in1
227- ),
228- L2_in = maybe_transpose_input_for_test_impl (
229- in2 .copy (), self .config .irreps_in2
230- ),
231- weights = weights .copy (),
236+ L1_in = rep_in1 ,
237+ L2_in = rep_in2 ,
238+ weights = rep_weights ,
232239 L3_out = repeated_run ,
233240 graph = graph ,
234241 )
235242
236- if self .config .layout == "ir_mul" :
237- repeated_run = IrrepLayoutUtils .transpose_irrep_layout (
238- repeated_run , self .config .irreps_out , "ir_mul" , "mul_ir"
239- )
243+ repeated_run = IrrepLayoutUtils .transpose_irrep_layout (
244+ repeated_run , self .config .irreps_out , self .config .layout , "mul_ir"
245+ )
240246
241247 for name , to_check , ground_truth in [
242248 ("output" , repeated_run , test_out )
@@ -413,13 +419,6 @@ def test_correctness_backward(
413419 reference_implementation = None ,
414420 high_precision_ref = False ,
415421 ):
416- def maybe_transpose_input_for_test_impl (x , irreps ):
417- if self .config .layout == "ir_mul" :
418- return IrrepLayoutUtils .transpose_irrep_layout (
419- x , irreps , "mul_ir" , "ir_mul"
420- )
421- return x
422-
423422 if reference_implementation is None :
424423 from openequivariance ._torch .E3NNConv import E3NNConv
425424
@@ -469,34 +468,39 @@ def maybe_transpose_input_for_test_impl(x, irreps):
469468 test_in1_grad = in1_grad .copy ()
470469 test_in2_grad = in2_grad .copy ()
471470
472- test_L3_grad = out_grad .copy ()
473- if self .config .layout == "ir_mul" :
474- test_L3_grad = IrrepLayoutUtils .transpose_irrep_layout (
475- test_L3_grad , self .config .irreps_out , "mul_ir" , "ir_mul"
471+ test_in1 , test_in2 , test_L3_grad = [
472+ buf .copy () for buf in (in1 , in2 , out_grad )
473+ ]
474+ test_in1 , test_in2 , test_L3_grad = [
475+ IrrepLayoutUtils .transpose_irrep_layout (
476+ arr , irreps , "mul_ir" , self .config .layout
477+ )
478+ for arr , irreps in zip (
479+ (test_in1 , test_in2 , test_L3_grad ),
480+ (self .config .irreps_in1 , self .config .irreps_in2 , self .config .irreps_out ),
476481 )
482+ ]
477483
478484 self .backward_cpu (
479- L1_in = maybe_transpose_input_for_test_impl (
480- in1 .copy (), self .config .irreps_in1
481- ),
485+ L1_in = test_in1 ,
482486 L1_grad = test_in1_grad ,
483- L2_in = maybe_transpose_input_for_test_impl (
484- in2 .copy (), self .config .irreps_in2
485- ),
487+ L2_in = test_in2 ,
486488 L2_grad = test_in2_grad ,
487489 L3_grad = test_L3_grad ,
488490 weights = weights .copy (),
489491 weights_grad = test_weights_grad ,
490492 graph = graph ,
491493 )
492494
493- if self . config . layout == "ir_mul" :
494- test_in1_grad = IrrepLayoutUtils .transpose_irrep_layout (
495- test_in1_grad , self .config .irreps_in1 , "ir_mul" , "mul_ir"
495+ test_in1_grad , test_in2_grad = [
496+ IrrepLayoutUtils .transpose_irrep_layout (
497+ arr , irreps , self .config .layout , "mul_ir"
496498 )
497- test_in2_grad = IrrepLayoutUtils .transpose_irrep_layout (
498- test_in2_grad , self .config .irreps_in2 , "ir_mul" , "mul_ir"
499+ for arr , irreps in zip (
500+ (test_in1_grad , test_in2_grad ),
501+ (self .config .irreps_in1 , self .config .irreps_in2 ),
499502 )
503+ ]
500504
501505 for name , to_check , ground_truth , threshold in [
502506 ("weight_grad" , test_weights_grad , ref_weights_grad , thresh ),
@@ -515,13 +519,6 @@ def test_correctness_double_backward(
515519 reference_implementation = None ,
516520 high_precision_ref = False ,
517521 ):
518- def maybe_transpose_input_for_test_impl (tp , x , irreps ):
519- if tp is self and tp .config .layout == "ir_mul" :
520- return IrrepLayoutUtils .transpose_irrep_layout (
521- x , irreps , "mul_ir" , "ir_mul"
522- )
523- return x
524-
525522 buffers = get_random_buffers_double_backward_conv (
526523 self .config , graph .node_count , graph .nnz , prng_seed
527524 )
@@ -542,6 +539,7 @@ def maybe_transpose_input_for_test_impl(tp, x, irreps):
542539 result = {"thresh" : thresh }
543540 tensors = []
544541 for i , tp in enumerate ([self , reference_tp ]):
542+ is_test_impl = i == 0
545543 buffers_copy = [buf .copy () for buf in buffers ]
546544
547545 if i == 1 and high_precision_ref :
@@ -558,21 +556,25 @@ def maybe_transpose_input_for_test_impl(tp, x, irreps):
558556 weights_dgrad , not tp .config .shared_weights
559557 )
560558
561- db_in1 = maybe_transpose_input_for_test_impl (tp , in1 , tp .config .irreps_in1 )
562- db_in2 = maybe_transpose_input_for_test_impl (tp , in2 , tp .config .irreps_in2 )
563- db_out_grad = out_grad
564- db_in1_dgrad = in1_dgrad
565- db_in2_dgrad = in2_dgrad
566- if tp is self and tp .config .layout == "ir_mul" :
567- db_out_grad = IrrepLayoutUtils .transpose_irrep_layout (
568- out_grad , tp .config .irreps_out , "mul_ir" , "ir_mul"
569- )
570- db_in1_dgrad = IrrepLayoutUtils .transpose_irrep_layout (
571- in1_dgrad , tp .config .irreps_in1 , "mul_ir" , "ir_mul"
572- )
573- db_in2_dgrad = IrrepLayoutUtils .transpose_irrep_layout (
574- in2_dgrad , tp .config .irreps_in2 , "mul_ir" , "ir_mul"
575- )
559+ db_in1 , db_in2 , db_out_grad , db_in1_dgrad , db_in2_dgrad = [
560+ buf .copy () for buf in (in1 , in2 , out_grad , in1_dgrad , in2_dgrad )
561+ ]
562+ if is_test_impl :
563+ db_in1 , db_in2 , db_out_grad , db_in1_dgrad , db_in2_dgrad = [
564+ IrrepLayoutUtils .transpose_irrep_layout (
565+ arr , irreps , "mul_ir" , tp .config .layout
566+ )
567+ for arr , irreps in zip (
568+ (db_in1 , db_in2 , db_out_grad , db_in1_dgrad , db_in2_dgrad ),
569+ (
570+ tp .config .irreps_in1 ,
571+ tp .config .irreps_in2 ,
572+ tp .config .irreps_out ,
573+ tp .config .irreps_in1 ,
574+ tp .config .irreps_in2 ,
575+ ),
576+ )
577+ ]
576578
577579 in1_grad , in2_grad , weights_grad , out_dgrad = tp .double_backward_cpu (
578580 db_in1 ,
@@ -585,16 +587,16 @@ def maybe_transpose_input_for_test_impl(tp, x, irreps):
585587 graph ,
586588 )
587589
588- if tp is self and tp . config . layout == "ir_mul" :
589- out_dgrad = IrrepLayoutUtils . transpose_irrep_layout (
590- out_dgrad , tp . config . irreps_out , "ir_mul" , "mul_ir"
591- )
592- in1_grad = IrrepLayoutUtils . transpose_irrep_layout (
593- in1_grad , tp . config . irreps_in1 , "ir_mul" , "mul_ir"
594- )
595- in2_grad = IrrepLayoutUtils . transpose_irrep_layout (
596- in2_grad , tp . config . irreps_in2 , "ir_mul" , "mul_ir"
597- )
590+ if is_test_impl :
591+ out_dgrad , in1_grad , in2_grad = [
592+ IrrepLayoutUtils . transpose_irrep_layout (
593+ arr , irreps , tp . config . layout , "mul_ir"
594+ )
595+ for arr , irreps in zip (
596+ ( out_dgrad , in1_grad , in2_grad ),
597+ ( tp . config . irreps_out , tp . config . irreps_in1 , tp . config . irreps_in2 ),
598+ )
599+ ]
598600
599601 tensors .append (
600602 (
0 commit comments