 import torch
 
+from openequivariance.core.utils import IrrepLayoutUtils
+
 
 class NumpyDoubleBackwardMixin:
     """
@@ -13,12 +15,30 @@ def double_backward_cpu(
     ):
         assert self.torch_op
 
-        in1_torch = torch.tensor(in1).to("cuda").requires_grad_(True)
-        in2_torch = torch.tensor(in2).to("cuda").requires_grad_(True)
+        layout = self.config.layout
+
+        in1_kernel = IrrepLayoutUtils.transpose_irrep_layout(
+            in1, self.config.irreps_in1, layout, "mul_ir"
+        )
+        in2_kernel = IrrepLayoutUtils.transpose_irrep_layout(
+            in2, self.config.irreps_in2, layout, "mul_ir"
+        )
+        out_grad_kernel = IrrepLayoutUtils.transpose_irrep_layout(
+            out_grad, self.config.irreps_out, layout, "mul_ir"
+        )
+        in1_dgrad_kernel = IrrepLayoutUtils.transpose_irrep_layout(
+            in1_dgrad, self.config.irreps_in1, layout, "mul_ir"
+        )
+        in2_dgrad_kernel = IrrepLayoutUtils.transpose_irrep_layout(
+            in2_dgrad, self.config.irreps_in2, layout, "mul_ir"
+        )
+
+        in1_torch = torch.tensor(in1_kernel).to("cuda").requires_grad_(True)
+        in2_torch = torch.tensor(in2_kernel).to("cuda").requires_grad_(True)
         weights_torch = torch.tensor(weights).to("cuda").requires_grad_(True)
-        out_grad_torch = torch.tensor(out_grad).to("cuda").requires_grad_(True)
-        in1_dgrad_torch = torch.tensor(in1_dgrad).to("cuda")
-        in2_dgrad_torch = torch.tensor(in2_dgrad).to("cuda")
+        out_grad_torch = torch.tensor(out_grad_kernel).to("cuda").requires_grad_(True)
+        in1_dgrad_torch = torch.tensor(in1_dgrad_kernel).to("cuda")
+        in2_dgrad_torch = torch.tensor(in2_dgrad_kernel).to("cuda")
         weights_dgrad_torch = torch.tensor(weights_dgrad).to("cuda")
         out_torch = self.forward(in1_torch, in2_torch, weights_torch)
 
@@ -36,12 +56,22 @@ def double_backward_cpu(
             grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch],
         )
 
-        return (
-            a.detach().cpu().numpy(),
-            b.detach().cpu().numpy(),
-            c.detach().cpu().numpy(),
-            d.detach().cpu().numpy(),
+        a_np = a.detach().cpu().numpy()
+        b_np = b.detach().cpu().numpy()
+        c_np = c.detach().cpu().numpy()
+        d_np = d.detach().cpu().numpy()
+
+        a_np = IrrepLayoutUtils.transpose_irrep_layout(
+            a_np, self.config.irreps_in1, "mul_ir", layout
         )
+        b_np = IrrepLayoutUtils.transpose_irrep_layout(
+            b_np, self.config.irreps_in2, "mul_ir", layout
+        )
+        d_np = IrrepLayoutUtils.transpose_irrep_layout(
+            d_np, self.config.irreps_out, "mul_ir", layout
+        )
+
+        return (a_np, b_np, c_np, d_np)
 
 
 class NumpyDoubleBackwardMixinConv:
@@ -54,12 +84,30 @@ def double_backward_cpu(
     ):
         assert self.torch_op
 
-        in1_torch = torch.tensor(in1).to("cuda").requires_grad_(True)
-        in2_torch = torch.tensor(in2).to("cuda").requires_grad_(True)
+        layout = self.config.layout
+
+        in1_kernel = IrrepLayoutUtils.transpose_irrep_layout(
+            in1, self.config.irreps_in1, layout, "mul_ir"
+        )
+        in2_kernel = IrrepLayoutUtils.transpose_irrep_layout(
+            in2, self.config.irreps_in2, layout, "mul_ir"
+        )
+        out_grad_kernel = IrrepLayoutUtils.transpose_irrep_layout(
+            out_grad, self.config.irreps_out, layout, "mul_ir"
+        )
+        in1_dgrad_kernel = IrrepLayoutUtils.transpose_irrep_layout(
+            in1_dgrad, self.config.irreps_in1, layout, "mul_ir"
+        )
+        in2_dgrad_kernel = IrrepLayoutUtils.transpose_irrep_layout(
+            in2_dgrad, self.config.irreps_in2, layout, "mul_ir"
+        )
+
+        in1_torch = torch.tensor(in1_kernel).to("cuda").requires_grad_(True)
+        in2_torch = torch.tensor(in2_kernel).to("cuda").requires_grad_(True)
         weights_torch = torch.tensor(weights).to("cuda").requires_grad_(True)
-        out_grad_torch = torch.tensor(out_grad).to("cuda").requires_grad_(True)
-        in1_dgrad_torch = torch.tensor(in1_dgrad).to("cuda")
-        in2_dgrad_torch = torch.tensor(in2_dgrad).to("cuda")
+        out_grad_torch = torch.tensor(out_grad_kernel).to("cuda").requires_grad_(True)
+        in1_dgrad_torch = torch.tensor(in1_dgrad_kernel).to("cuda")
+        in2_dgrad_torch = torch.tensor(in2_dgrad_kernel).to("cuda")
         weights_dgrad_torch = torch.tensor(weights_dgrad).to("cuda")
 
         torch_rows = torch.tensor(graph.rows, device="cuda")
@@ -89,9 +137,19 @@ def double_backward_cpu(
             grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch],
         )
 
-        return (
-            a.detach().cpu().numpy(),
-            b.detach().cpu().numpy(),
-            c.detach().cpu().numpy(),
-            d.detach().cpu().numpy(),
+        a_np = a.detach().cpu().numpy()
+        b_np = b.detach().cpu().numpy()
+        c_np = c.detach().cpu().numpy()
+        d_np = d.detach().cpu().numpy()
+
+        a_np = IrrepLayoutUtils.transpose_irrep_layout(
+            a_np, self.config.irreps_in1, "mul_ir", layout
+        )
+        b_np = IrrepLayoutUtils.transpose_irrep_layout(
+            b_np, self.config.irreps_in2, "mul_ir", layout
         )
+        d_np = IrrepLayoutUtils.transpose_irrep_layout(
+            d_np, self.config.irreps_out, "mul_ir", layout
+        )
+
+        return (a_np, b_np, c_np, d_np)
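
Note on the layout round trip this diff introduces: the kernel evidently operates in "mul_ir" order, so both mixins transpose every irrep-carrying array to "mul_ir" before building the CUDA tensors and transpose the resulting gradients back to the configured layout before returning; the weight gradient `c` is returned as-is, presumably because weights carry no irrep layout. The sketch below illustrates what `IrrepLayoutUtils.transpose_irrep_layout` is assumed to do, with the four-argument signature (array, irreps, source layout, destination layout) inferred from the call sites; `transpose_irrep_layout_sketch`, its `(mul, l)` irreps encoding, and the "ir_mul" layout name are hypothetical stand-ins, not the library implementation.

import numpy as np

# Hypothetical stand-in for IrrepLayoutUtils.transpose_irrep_layout (assumption:
# "mul_ir" stores each irrep block as (multiplicity, 2l + 1) and "ir_mul" stores
# the transpose). Works block by block over the flat trailing feature dimension.
def transpose_irrep_layout_sketch(x, irreps, src_layout, dst_layout):
    if src_layout == dst_layout:
        return x
    out = np.empty_like(x)
    offset = 0
    for mul, l in irreps:  # irreps encoded here as (multiplicity, degree l) pairs
        dim = 2 * l + 1
        block = x[..., offset : offset + mul * dim]
        # Reinterpret the flat block in the source order, then swap the two axes.
        shape = (mul, dim) if src_layout == "mul_ir" else (dim, mul)
        block = block.reshape(*x.shape[:-1], *shape).swapaxes(-1, -2)
        out[..., offset : offset + mul * dim] = block.reshape(*x.shape[:-1], -1)
        offset += mul * dim
    return out

# Round trip: "ir_mul" -> "mul_ir" (pre-kernel) -> "ir_mul" (post-kernel).
irreps = [(2, 1), (3, 2)]  # e.g. "2x1e + 3x2e": 2*3 + 3*5 = 21 floats per row
x = np.random.rand(4, 21)  # batch of 4 feature vectors
y = transpose_irrep_layout_sketch(x, irreps, "ir_mul", "mul_ir")
assert np.allclose(transpose_irrep_layout_sketch(y, irreps, "mul_ir", "ir_mul"), x)

The final assert checks the property the diff relies on: transposing into "mul_ir" and back to the original layout is the identity, so the round trip changes only the ordering the kernel sees, never the values.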