Skip to content

Commit 64c5c56

Browse files
committed
Reordering starting to work...
1 parent 6452140 commit 64c5c56

1 file changed

Lines changed: 4 additions & 1 deletion

File tree

openequivariance/openequivariance/impl_jax/TensorProductConv.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph):
174174
rows = graph.rows.astype(np.int32)
175175
cols = graph.cols.astype(np.int32)
176176
sender_perm = graph.transpose_perm.astype(np.int32)
177+
weights = self.reorder_weights_from_e3nn(weights, has_batch_dim=not self.config.shared_weights)
177178
result = self.forward(
178179
jax.numpy.asarray(L1_in),
179180
jax.numpy.asarray(L2_in),
@@ -198,6 +199,7 @@ def backward_cpu(
198199
rows = graph.rows.astype(np.int32)
199200
cols = graph.cols.astype(np.int32)
200201
sender_perm = graph.transpose_perm.astype(np.int32)
202+
weights = self.reorder_weights_from_e3nn(weights, has_batch_dim=not self.config.shared_weights)
201203

202204
backward_fn = jax.vjp(
203205
lambda X, Y, W: self.forward(
@@ -217,4 +219,5 @@ def backward_cpu(
217219
)
218220
L1_grad[:] = np.asarray(L1_grad_jax)
219221
L2_grad[:] = np.asarray(L2_grad_jax)
220-
weights_grad[:] = np.asarray(weights_grad_jax)
222+
weights_grad[:] = np.asarray(weights_grad_jax)
223+
weights_grad[:] = self.reorder_weights_to_e3nn(weights_grad, has_batch_dim=not self.config.shared_weights)

0 commit comments

Comments (0)