1616)
1717from openequivariance .core .e3nn_lite import TPProblem
1818from openequivariance .core .TensorProductBase import TensorProductBase
19- from openequivariance .core .utils import IrrepLayoutUtils
19+ from openequivariance .core .utils import transpose_irrep_layout
2020
2121logger = getLogger ()
2222
@@ -87,29 +87,31 @@ def correctness_forward(
8787 outputs = []
8888
8989 for i , impl in enumerate ([test_implementation , reference_implementation ]):
90- is_test_impl = ( i == 0 )
90+ is_test_impl = i == 0
9191 tp = instantiate_implementation (impl , problem )
9292 uses_cue = impl == CUETensorProduct or isinstance (tp , CUETensorProduct )
93- run_in1 , run_in2 , run_weights , run_out = [ buf .copy () for buf in (in1 , in2 , weights , out ) ]
93+ run_in1 , run_in2 , run_weights , run_out = [
94+ buf .copy () for buf in (in1 , in2 , weights , out )
95+ ]
9496
9597 if problem .shared_weights and uses_cue :
9698 run_weights = run_weights [np .newaxis , :]
9799
98- # Transpose inputs, if necessary, for the test implementation
100+ # Transpose inputs, if necessary, for the test implementation
99101 if is_test_impl :
100102 run_in1 , run_in2 = [
101- IrrepLayoutUtils .transpose_irrep_layout (
102- arr , irreps , "mul_ir" , tp .config .layout
103- ) for arr , irreps in zip (
104- (run_in1 , run_in2 ),
105- (problem .irreps_in1 , problem .irreps_in2 )
103+ transpose_irrep_layout (arr , irreps , "mul_ir" , tp .config .layout )
104+ for arr , irreps in zip (
105+ (run_in1 , run_in2 ), (problem .irreps_in1 , problem .irreps_in2 )
106106 )
107107 ]
108108
109- tp .forward_cpu (L1_in = run_in1 , L2_in = run_in2 , L3_out = run_out , weights = run_weights )
109+ tp .forward_cpu (
110+ L1_in = run_in1 , L2_in = run_in2 , L3_out = run_out , weights = run_weights
111+ )
110112
111113 if is_test_impl :
112- run_out = IrrepLayoutUtils . transpose_irrep_layout (
114+ run_out = transpose_irrep_layout (
113115 run_out , problem .irreps_out , tp .config .layout , "mul_ir"
114116 )
115117
@@ -147,7 +149,15 @@ def correctness_backward(
147149 is_test_impl = i == 0
148150 tp = instantiate_implementation (impl , problem )
149151
150- run_in1 , run_in2 , run_L3_grad , run_weights , run_weights_grad , run_in1_grad , run_in2_grad = [
152+ (
153+ run_in1 ,
154+ run_in2 ,
155+ run_L3_grad ,
156+ run_weights ,
157+ run_weights_grad ,
158+ run_in1_grad ,
159+ run_in2_grad ,
160+ ) = [
151161 buf .copy ()
152162 for buf in (in1 , in2 , out_grad , weights , weights_grad , in1_grad , in2_grad )
153163 ]
@@ -159,9 +169,7 @@ def correctness_backward(
159169
160170 if is_test_impl :
161171 run_in1 , run_in2 , run_L3_grad = [
162- IrrepLayoutUtils .transpose_irrep_layout (
163- arr , irreps , "mul_ir" , tp .config .layout
164- )
172+ transpose_irrep_layout (arr , irreps , "mul_ir" , tp .config .layout )
165173 for arr , irreps in zip (
166174 (run_in1 , run_in2 , run_L3_grad ),
167175 (problem .irreps_in1 , problem .irreps_in2 , problem .irreps_out ),
@@ -180,9 +188,7 @@ def correctness_backward(
180188
181189 if is_test_impl :
182190 run_in1_grad , run_in2_grad = [
183- IrrepLayoutUtils .transpose_irrep_layout (
184- arr , irreps , tp .config .layout , "mul_ir"
185- )
191+ transpose_irrep_layout (arr , irreps , tp .config .layout , "mul_ir" )
186192 for arr , irreps in zip (
187193 (run_in1_grad , run_in2_grad ),
188194 (problem .irreps_in1 , problem .irreps_in2 ),
@@ -254,9 +260,7 @@ def correctness_double_backward(
254260
255261 if is_test_impl :
256262 db_in1 , db_in2 , db_out_grad , db_in1_dgrad , db_in2_dgrad = [
257- IrrepLayoutUtils .transpose_irrep_layout (
258- arr , irreps , "mul_ir" , tp .config .layout
259- )
263+ transpose_irrep_layout (arr , irreps , "mul_ir" , tp .config .layout )
260264 for arr , irreps in zip (
261265 (db_in1 , db_in2 , db_out_grad , db_in1_dgrad , db_in2_dgrad ),
262266 (
@@ -281,9 +285,7 @@ def correctness_double_backward(
281285
282286 if is_test_impl :
283287 out_dgrad , in1_grad , in2_grad = [
284- IrrepLayoutUtils .transpose_irrep_layout (
285- arr , irreps , tp .config .layout , "mul_ir"
286- )
288+ transpose_irrep_layout (arr , irreps , tp .config .layout , "mul_ir" )
287289 for arr , irreps in zip (
288290 (out_dgrad , in1_grad , in2_grad ),
289291 (problem .irreps_out , problem .irreps_in1 , problem .irreps_in2 ),
@@ -359,9 +361,7 @@ def correctness_forward_conv(
359361
360362 if is_test_impl :
361363 run_in1 , run_in2 = [
362- IrrepLayoutUtils .transpose_irrep_layout (
363- arr , irreps , "mul_ir" , conv .config .layout
364- )
364+ transpose_irrep_layout (arr , irreps , "mul_ir" , conv .config .layout )
365365 for arr , irreps in zip (
366366 (run_in1 , run_in2 ),
367367 (conv .config .irreps_in1 , conv .config .irreps_in2 ),
@@ -375,7 +375,7 @@ def correctness_forward_conv(
375375 graph = graph ,
376376 )
377377
378- run_out = IrrepLayoutUtils . transpose_irrep_layout (
378+ run_out = transpose_irrep_layout (
379379 run_out , conv .config .irreps_out , conv .config .layout , "mul_ir"
380380 )
381381 else :
@@ -410,13 +410,9 @@ def correctness_forward_conv(
410410
411411 for _ in range (num_trials ):
412412 repeated_run = out .copy ()
413- rep_in1 , rep_in2 , rep_weights = [
414- buf .copy () for buf in (in1 , in2 , weights )
415- ]
413+ rep_in1 , rep_in2 , rep_weights = [buf .copy () for buf in (in1 , in2 , weights )]
416414 rep_in1 , rep_in2 = [
417- IrrepLayoutUtils .transpose_irrep_layout (
418- arr , irreps , "mul_ir" , conv .config .layout
419- )
415+ transpose_irrep_layout (arr , irreps , "mul_ir" , conv .config .layout )
420416 for arr , irreps in zip (
421417 (rep_in1 , rep_in2 ),
422418 (conv .config .irreps_in1 , conv .config .irreps_in2 ),
@@ -430,7 +426,7 @@ def correctness_forward_conv(
430426 graph = graph ,
431427 )
432428
433- repeated_run = IrrepLayoutUtils . transpose_irrep_layout (
429+ repeated_run = transpose_irrep_layout (
434430 repeated_run , conv .config .irreps_out , conv .config .layout , "mul_ir"
435431 )
436432
@@ -471,9 +467,15 @@ def correctness_backward_conv(
471467 is_test_impl = i == 0
472468 tp = impl if is_test_impl else impl (reference_problem )
473469
474- run_in1 , run_in2 , run_out_grad , run_weights , run_weights_grad , run_in1_grad , run_in2_grad = [
475- buf .copy () for buf in buffers
476- ]
470+ (
471+ run_in1 ,
472+ run_in2 ,
473+ run_out_grad ,
474+ run_weights ,
475+ run_weights_grad ,
476+ run_in1_grad ,
477+ run_in2_grad ,
478+ ) = [buf .copy () for buf in buffers ]
477479
478480 if not is_test_impl and high_precision_ref :
479481 (
@@ -488,12 +490,14 @@ def correctness_backward_conv(
488490
489491 if is_test_impl :
490492 run_in1 , run_in2 , run_out_grad = [
491- IrrepLayoutUtils .transpose_irrep_layout (
492- arr , irreps , "mul_ir" , conv .config .layout
493- )
493+ transpose_irrep_layout (arr , irreps , "mul_ir" , conv .config .layout )
494494 for arr , irreps in zip (
495495 (run_in1 , run_in2 , run_out_grad ),
496- (conv .config .irreps_in1 , conv .config .irreps_in2 , conv .config .irreps_out ),
496+ (
497+ conv .config .irreps_in1 ,
498+ conv .config .irreps_in2 ,
499+ conv .config .irreps_out ,
500+ ),
497501 )
498502 ]
499503
@@ -510,9 +514,7 @@ def correctness_backward_conv(
510514
511515 if is_test_impl :
512516 run_in1_grad , run_in2_grad = [
513- IrrepLayoutUtils .transpose_irrep_layout (
514- arr , irreps , conv .config .layout , "mul_ir"
515- )
517+ transpose_irrep_layout (arr , irreps , conv .config .layout , "mul_ir" )
516518 for arr , irreps in zip (
517519 (run_in1_grad , run_in2_grad ),
518520 (conv .config .irreps_in1 , conv .config .irreps_in2 ),
@@ -581,9 +583,7 @@ def correctness_double_backward_conv(
581583 ]
582584 if is_test_impl :
583585 db_in1 , db_in2 , db_out_grad , db_in1_dgrad , db_in2_dgrad = [
584- IrrepLayoutUtils .transpose_irrep_layout (
585- arr , irreps , "mul_ir" , tp .config .layout
586- )
586+ transpose_irrep_layout (arr , irreps , "mul_ir" , tp .config .layout )
587587 for arr , irreps in zip (
588588 (db_in1 , db_in2 , db_out_grad , db_in1_dgrad , db_in2_dgrad ),
589589 (
@@ -609,9 +609,7 @@ def correctness_double_backward_conv(
609609
610610 if is_test_impl :
611611 out_dgrad , in1_grad , in2_grad = [
612- IrrepLayoutUtils .transpose_irrep_layout (
613- arr , irreps , tp .config .layout , "mul_ir"
614- )
612+ transpose_irrep_layout (arr , irreps , tp .config .layout , "mul_ir" )
615613 for arr , irreps in zip (
616614 (out_dgrad , in1_grad , in2_grad ),
617615 (tp .config .irreps_out , tp .config .irreps_in1 , tp .config .irreps_in2 ),
0 commit comments