11import copy
2+
23import numpy as np
4+
5+ from openequivariance .benchmark .correctness_utils import check_similiarity
6+ from openequivariance .benchmark .logging_utils import bcolors , getLogger
37from openequivariance .benchmark .random_buffer_utils import (
4- get_random_buffers_forward_conv ,
58 get_random_buffers_backward_conv ,
69 get_random_buffers_double_backward_conv ,
10+ get_random_buffers_forward_conv ,
711)
8-
9- from openequivariance .benchmark .logging_utils import getLogger , bcolors
10- from openequivariance .benchmark .correctness_utils import check_similiarity
1112from openequivariance .core .e3nn_lite import wigner_3j
12- from openequivariance .core .utils import benchmark
13+ from openequivariance .core .utils import IrrepLayoutUtils , benchmark
1314
1415logger = getLogger ()
1516
@@ -143,6 +144,13 @@ def test_correctness_forward(
143144 check_reproducible = True ,
144145 high_precision_ref = False ,
145146 ):
147+ def maybe_transpose_input_for_test_impl (x , irreps ):
148+ if self .config .layout == "ir_mul" :
149+ return IrrepLayoutUtils .transpose_irrep_layout (
150+ x , irreps , "mul_ir" , "ir_mul"
151+ )
152+ return x
153+
146154 if reference_implementation is None :
147155 from openequivariance ._torch .E3NNConv import E3NNConv
148156
@@ -186,13 +194,22 @@ def test_correctness_forward(
186194
187195 test_out = out .copy ()
188196 self .forward_cpu (
189- L1_in = in1 .copy (),
190- L2_in = in2 .copy (),
197+ L1_in = maybe_transpose_input_for_test_impl (
198+ in1 .copy (), self .config .irreps_in1
199+ ),
200+ L2_in = maybe_transpose_input_for_test_impl (
201+ in2 .copy (), self .config .irreps_in2
202+ ),
191203 weights = weights .copy (),
192204 L3_out = test_out ,
193205 graph = graph ,
194206 )
195207
208+ if self .config .layout == "ir_mul" :
209+ test_out = IrrepLayoutUtils .transpose_irrep_layout (
210+ test_out , self .config .irreps_out , "ir_mul" , "mul_ir"
211+ )
212+
196213 for name , to_check , ground_truth in [("output" , ref_out , test_out )]:
197214 result [name ] = check_similiarity (name , to_check , ground_truth , thresh )
198215
@@ -205,13 +222,22 @@ def test_correctness_forward(
205222 for i in range (num_trials ):
206223 repeated_run = out .copy ()
207224 self .forward_cpu (
208- L1_in = in1 .copy (),
209- L2_in = in2 .copy (),
225+ L1_in = maybe_transpose_input_for_test_impl (
226+ in1 .copy (), self .config .irreps_in1
227+ ),
228+ L2_in = maybe_transpose_input_for_test_impl (
229+ in2 .copy (), self .config .irreps_in2
230+ ),
210231 weights = weights .copy (),
211232 L3_out = repeated_run ,
212233 graph = graph ,
213234 )
214235
236+ if self .config .layout == "ir_mul" :
237+ repeated_run = IrrepLayoutUtils .transpose_irrep_layout (
238+ repeated_run , self .config .irreps_out , "ir_mul" , "mul_ir"
239+ )
240+
215241 for name , to_check , ground_truth in [
216242 ("output" , repeated_run , test_out )
217243 ]:
@@ -387,6 +413,13 @@ def test_correctness_backward(
387413 reference_implementation = None ,
388414 high_precision_ref = False ,
389415 ):
416+ def maybe_transpose_input_for_test_impl (x , irreps ):
417+ if self .config .layout == "ir_mul" :
418+ return IrrepLayoutUtils .transpose_irrep_layout (
419+ x , irreps , "mul_ir" , "ir_mul"
420+ )
421+ return x
422+
390423 if reference_implementation is None :
391424 from openequivariance ._torch .E3NNConv import E3NNConv
392425
@@ -436,17 +469,35 @@ def test_correctness_backward(
436469 test_in1_grad = in1_grad .copy ()
437470 test_in2_grad = in2_grad .copy ()
438471
472+ test_L3_grad = out_grad .copy ()
473+ if self .config .layout == "ir_mul" :
474+ test_L3_grad = IrrepLayoutUtils .transpose_irrep_layout (
475+ test_L3_grad , self .config .irreps_out , "mul_ir" , "ir_mul"
476+ )
477+
439478 self .backward_cpu (
440- L1_in = in1 .copy (),
479+ L1_in = maybe_transpose_input_for_test_impl (
480+ in1 .copy (), self .config .irreps_in1
481+ ),
441482 L1_grad = test_in1_grad ,
442- L2_in = in2 .copy (),
483+ L2_in = maybe_transpose_input_for_test_impl (
484+ in2 .copy (), self .config .irreps_in2
485+ ),
443486 L2_grad = test_in2_grad ,
444- L3_grad = out_grad . copy () ,
487+ L3_grad = test_L3_grad ,
445488 weights = weights .copy (),
446489 weights_grad = test_weights_grad ,
447490 graph = graph ,
448491 )
449492
493+ if self .config .layout == "ir_mul" :
494+ test_in1_grad = IrrepLayoutUtils .transpose_irrep_layout (
495+ test_in1_grad , self .config .irreps_in1 , "ir_mul" , "mul_ir"
496+ )
497+ test_in2_grad = IrrepLayoutUtils .transpose_irrep_layout (
498+ test_in2_grad , self .config .irreps_in2 , "ir_mul" , "mul_ir"
499+ )
500+
450501 for name , to_check , ground_truth , threshold in [
451502 ("weight_grad" , test_weights_grad , ref_weights_grad , thresh ),
452503 ("in1_grad" , test_in1_grad , ref_in1_grad , thresh ),
@@ -464,6 +515,13 @@ def test_correctness_double_backward(
464515 reference_implementation = None ,
465516 high_precision_ref = False ,
466517 ):
518+ def maybe_transpose_input_for_test_impl (tp , x , irreps ):
519+ if tp is self and tp .config .layout == "ir_mul" :
520+ return IrrepLayoutUtils .transpose_irrep_layout (
521+ x , irreps , "mul_ir" , "ir_mul"
522+ )
523+ return x
524+
467525 buffers = get_random_buffers_double_backward_conv (
468526 self .config , graph .node_count , graph .nnz , prng_seed
469527 )
@@ -500,17 +558,44 @@ def test_correctness_double_backward(
500558 weights_dgrad , not tp .config .shared_weights
501559 )
502560
561+ db_in1 = maybe_transpose_input_for_test_impl (tp , in1 , tp .config .irreps_in1 )
562+ db_in2 = maybe_transpose_input_for_test_impl (tp , in2 , tp .config .irreps_in2 )
563+ db_out_grad = out_grad
564+ db_in1_dgrad = in1_dgrad
565+ db_in2_dgrad = in2_dgrad
566+ if tp is self and tp .config .layout == "ir_mul" :
567+ db_out_grad = IrrepLayoutUtils .transpose_irrep_layout (
568+ out_grad , tp .config .irreps_out , "mul_ir" , "ir_mul"
569+ )
570+ db_in1_dgrad = IrrepLayoutUtils .transpose_irrep_layout (
571+ in1_dgrad , tp .config .irreps_in1 , "mul_ir" , "ir_mul"
572+ )
573+ db_in2_dgrad = IrrepLayoutUtils .transpose_irrep_layout (
574+ in2_dgrad , tp .config .irreps_in2 , "mul_ir" , "ir_mul"
575+ )
576+
503577 in1_grad , in2_grad , weights_grad , out_dgrad = tp .double_backward_cpu (
504- in1 ,
505- in2 ,
506- out_grad ,
578+ db_in1 ,
579+ db_in2 ,
580+ db_out_grad ,
507581 weights_reordered ,
508582 weights_dgrad_reordered ,
509- in1_dgrad ,
510- in2_dgrad ,
583+ db_in1_dgrad ,
584+ db_in2_dgrad ,
511585 graph ,
512586 )
513587
588+ if tp is self and tp .config .layout == "ir_mul" :
589+ out_dgrad = IrrepLayoutUtils .transpose_irrep_layout (
590+ out_dgrad , tp .config .irreps_out , "ir_mul" , "mul_ir"
591+ )
592+ in1_grad = IrrepLayoutUtils .transpose_irrep_layout (
593+ in1_grad , tp .config .irreps_in1 , "ir_mul" , "mul_ir"
594+ )
595+ in2_grad = IrrepLayoutUtils .transpose_irrep_layout (
596+ in2_grad , tp .config .irreps_in2 , "ir_mul" , "mul_ir"
597+ )
598+
514599 tensors .append (
515600 (
516601 out_dgrad ,
0 commit comments