@@ -117,6 +117,18 @@ def __init__(
117117 self .workspace_ptr = 0
118118 self .workspace_size = 0
119119
def reorder_weights_from_e3nn(self, weights, has_batch_dim=True):
    r"""
    See :py:func:`oeq.TensorProduct.reorder_weights_from_e3nn`.

    This implementation is an identity passthrough: ``weights`` is
    returned unchanged (the ``has_batch_dim`` flag is accepted for
    interface compatibility but unused here).
    NOTE(review): presumably overridden where an actual e3nn -> oeq
    weight permutation is required — confirm against subclasses.
    """
    return weights
125+
def reorder_weights_to_e3nn(self, weights, has_batch_dim=True):
    r"""
    See :py:func:`oeq.TensorProduct.reorder_weights_to_e3nn`.

    Identity passthrough: ``weights`` comes back unmodified and the
    ``has_batch_dim`` argument is ignored in this base implementation.
    NOTE(review): presumably overridden where an actual oeq -> e3nn
    weight permutation is required — confirm against subclasses.
    """
    return weights
131+
120132 def allocate_workspace (self , size_bytes ):
121133 self .workspace_size = size_bytes
122134 if self .torch_op :
@@ -136,13 +148,9 @@ def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph):
136148 assert graph .rows .dtype == self .idx_dtype
137149 assert graph .cols .dtype == self .idx_dtype
138150
139- weights_chunked = np .zeros_like (weights )
140- if self .reorder_weights_e3nn_to_oeq is not None :
141- self .reorder_weights_e3nn_to_oeq (
142- weights , weights_chunked , not self .config .shared_weights
143- )
144- else :
145- weights_chunked = weights
151+ weights_chunked = self .reorder_weights_from_e3nn (
152+ weights , not self .config .shared_weights
153+ )
146154
147155 L1_d , L2_d , weights_d = (
148156 DeviceBuffer (L1_in ),
@@ -174,13 +182,9 @@ def backward_cpu(
174182 assert graph .rows .dtype == self .idx_dtype
175183 assert graph .cols .dtype == self .idx_dtype
176184
177- weights_chunked = np .zeros_like (weights )
178- if self .reorder_weights_e3nn_to_oeq is not None :
179- self .reorder_weights_e3nn_to_oeq (
180- weights , weights_chunked , not self .config .shared_weights
181- )
182- else :
183- weights_chunked = weights
185+ weights_chunked = self .reorder_weights_from_e3nn (
186+ weights , not self .config .shared_weights
187+ )
184188
185189 L1_d = DeviceBuffer (L1_in )
186190 L2_d = DeviceBuffer (L2_in )
@@ -219,11 +223,9 @@ def backward_cpu(
219223 L2_grad_d .copy_to_host ()
220224 weights_grad_d .copy_to_host ()
221225
222- if self .reorder_weights_oeq_to_e3nn is not None :
223- weights_grad_copy = weights_grad .copy ()
224- self .reorder_weights_oeq_to_e3nn (
225- weights_grad_copy , weights_grad , not self .config .shared_weights
226- )
226+ weights_grad [:] = self .reorder_weights_to_e3nn (
227+ weights_grad , not self .config .shared_weights
228+ )
227229
228230 return L1_grad , L2_grad , weights_grad
229231
@@ -712,17 +714,10 @@ def test_correctness_double_backward(
712714 in1_torch = torch .tensor (in1 , device = "cuda" , requires_grad = True )
713715 in2_torch = torch .tensor (in2 , device = "cuda" , requires_grad = True )
714716
715- weights_reordered = np .zeros_like (weights )
716- if (
717- i == 0
718- and hasattr (self , "reorder_weights_e3nn_to_oeq" )
719- and self .reorder_weights_e3nn_to_oeq is not None
720- ):
721- self .reorder_weights_e3nn_to_oeq (
722- weights , weights_reordered , not self .config .shared_weights
723- )
724- else :
725- weights_reordered [:] = weights
717+ weights_reordered = tp .reorder_weights_from_e3nn (
718+ weights , not self .config .shared_weights
719+ )
720+
726721 weights_torch = torch .tensor (
727722 weights_reordered , device = "cuda" , requires_grad = True
728723 )
@@ -754,15 +749,9 @@ def test_correctness_double_backward(
754749 )
755750
756751 weights_grad = weights_torch .grad .detach ().cpu ().numpy ()
757- if (
758- i == 0
759- and hasattr (self , "reorder_weights_e3nn_to_oeq" )
760- and self .reorder_weights_oeq_to_e3nn is not None
761- ):
762- weights_grad_copy = weights_grad .copy ()
763- self .reorder_weights_oeq_to_e3nn (
764- weights_grad_copy , weights_grad , not self .config .shared_weights
765- )
752+ weights_grad = tp .reorder_weights_to_e3nn (
753+ weights_grad , not self .config .shared_weights
754+ )
766755
767756 tensors .append (
768757 (
0 commit comments