
Commit 61e0566

Almost there, need to get TensorProductConv working.
Parent: 8a0094a

4 files changed: 9 additions & 10 deletions

openequivariance/openequivariance/core/ConvolutionBase.py

Lines changed: 4 additions & 4 deletions
@@ -578,19 +578,19 @@ def test_correctness_double_backward(
         in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, _ = buffers_copy

         weights_reordered = tp.reorder_weights_from_e3nn(
-            weights, not self.config.shared_weights
+            weights, not tp.config.shared_weights
         )
         weights_dgrad_reordered = tp.reorder_weights_from_e3nn(
-            weights_dgrad, not self.config.shared_weights
+            weights_dgrad, not tp.config.shared_weights
         )

-        in1_grad, in2_grad, weights_grad, out_dgrad = self.double_backward_cpu(in1, in2, out_grad, weights_reordered, weights_dgrad_reordered, in1_dgrad, in2_dgrad, graph)
+        in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu(in1, in2, out_grad, weights_reordered, weights_dgrad_reordered, in1_dgrad, in2_dgrad, graph)

         tensors.append(
             ( out_dgrad,
               in1_grad,
               in2_grad,
-              self.reorder_weights_to_e3nn(weights_grad, has_batch_dim=not self.config.shared_weights)
+              tp.reorder_weights_to_e3nn(weights_grad, has_batch_dim=not self.config.shared_weights)
             ))

         for name, to_check, ground_truth in [
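
The substantive change in this hunk: the correctness harness previously dispatched the weight-reorder and double-backward helpers to `self` (the reference object running the test) instead of `tp`, the implementation under test. Note that the `+` line for the final reorder still reads `self.config.shared_weights`, so one flag still comes from the harness rather than from `tp`. A minimal sketch of the corrected dispatch pattern, with hypothetical `Config`/`Impl`/`Reference` stand-ins (not OpenEquivariance's actual classes):

```python
# Minimal sketch of the dispatch fix; Config/Impl/Reference are
# hypothetical stand-ins, not OpenEquivariance's API.
from dataclasses import dataclass

@dataclass
class Config:
    shared_weights: bool

class Impl:
    def __init__(self, shared_weights=False):
        self.config = Config(shared_weights)

    def reorder_weights_from_e3nn(self, weights, has_batch_dim):
        print(f"{type(self).__name__}: has_batch_dim={has_batch_dim}")
        return weights

class Reference(Impl):
    def check(self, tp, weights):
        # Per-edge weights carry a batch dimension; shared weights do not.
        # Crucially this consults tp.config, not self.config, so the flag
        # matches the implementation actually under test.
        return tp.reorder_weights_from_e3nn(weights, not tp.config.shared_weights)

Reference().check(Impl(shared_weights=True), weights=[1.0])
# prints "Impl: has_batch_dim=False"
```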

openequivariance/openequivariance/impl_jax/TensorProductConv.py

Lines changed: 3 additions & 3 deletions
@@ -165,10 +165,10 @@ def __call__(
         return self.forward(X, Y, W, rows, cols, sender_perm)

     def reorder_weights_from_e3nn(self, weights, has_batch_dim=True):
-        return reorder_jax(self.forward_schedule, weights, "forward", not self.config.shared_weights)
+        return reorder_jax(self.forward_schedule, weights, "forward", has_batch_dim)

     def reorder_weights_to_e3nn(self, weights, has_batch_dim=True):
-        return reorder_jax(self.forward_schedule, weights, "backward", not self.config.shared_weights)
+        return reorder_jax(self.forward_schedule, weights, "backward", has_batch_dim)

     def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph):
         rows = graph.rows.astype(np.int32)
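
Before this change, both wrappers ignored their `has_batch_dim` argument and recomputed `not self.config.shared_weights`, silently overriding whatever the caller passed; the fix threads the parameter through. A minimal sketch of the fixed shape, with a stand-in `reorder_jax` (the real one reorders weight layouts according to the schedule):

```python
# Sketch only: this reorder_jax is a hypothetical stand-in that just
# reports its arguments; the real function permutes weight segments.
def reorder_jax(schedule, weights, direction, has_batch_dim):
    print(f"direction={direction}, has_batch_dim={has_batch_dim}")
    return weights

class TensorProductConv:
    def __init__(self, forward_schedule=None):
        self.forward_schedule = forward_schedule

    def reorder_weights_from_e3nn(self, weights, has_batch_dim=True):
        # Honor the caller's flag instead of rederiving it from config.
        return reorder_jax(self.forward_schedule, weights, "forward", has_batch_dim)

    def reorder_weights_to_e3nn(self, weights, has_batch_dim=True):
        return reorder_jax(self.forward_schedule, weights, "backward", has_batch_dim)

TensorProductConv().reorder_weights_from_e3nn([1.0], has_batch_dim=False)
# prints "direction=forward, has_batch_dim=False"
```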
@@ -240,4 +240,4 @@ def double_backward_cpu(self, in1, in2, out_grad, weights, weights_dgrad, in1_dg
             in1_jax, in2_jax, weights_jax, out_grad_jax
         )[1]((in1_dgrad_jax, in2_dgrad_jax, weights_dgrad_jax))

-        return in1_grad, in2_grad, weights_grad, out_dgrad
+        return np.asarray(in1_grad), np.asarray(in2_grad), np.asarray(weights_grad), np.asarray(out_dgrad)
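
The `(...)[1]((...))` call shape in `double_backward_cpu` appears to follow the `jax.vjp` pattern: a VJP of a VJP, i.e. a double backward pass; the new `np.asarray` wrapping then converts the resulting JAX device arrays to NumPy before the harness compares them. A runnable sketch under that assumption, with a toy function standing in for the convolution kernel:

```python
# Runnable sketch of double backward via nested jax.vjp; f is a toy
# stand-in, not OpenEquivariance's kernel.
import jax
import jax.numpy as jnp
import numpy as np

def f(in1, w):
    return jnp.sin(in1) * w

def backward(in1, w, out_grad):
    # First backward pass: the VJP of f gives (d_in1, d_w).
    _, vjp_fn = jax.vjp(f, in1, w)
    return vjp_fn(out_grad)

in1, w, out_grad = jnp.ones(3), jnp.full(3, 2.0), jnp.ones(3)
in1_dgrad, w_dgrad = jnp.ones(3), jnp.ones(3)

# Second backward pass: differentiate the first backward pass itself,
# seeding it with the upstream gradients of (d_in1, d_w).
in1_grad, w_grad, out_dgrad = jax.vjp(backward, in1, w, out_grad)[1](
    (in1_dgrad, w_dgrad)
)

# As in the commit, np.asarray pulls JAX arrays back to NumPy so a
# NumPy-based correctness check can consume them.
print(np.asarray(in1_grad), np.asarray(w_grad), np.asarray(out_dgrad))
```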

openequivariance/openequivariance/impl_torch/E3NNConv.py

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ def __init__(self, config, *, idx_dtype=np.int64, torch_op=True):
         if config.irrep_dtype == np.float64:
             torch.set_default_dtype(torch.float32)  # Reset to default

-    def forward(self, L1_in, L2_in, weights, rows, cols):
+    def forward(self, L1_in, L2_in, weights, rows, cols, transpose_perm=None):
         messages = self.reference_tp(L1_in[cols], L2_in, weights)
         return scatter_add_wrapper(messages, rows, L1_in.size(0))
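
This reference forward is a gather/compute/scatter message-passing step: gather sender features with `L1_in[cols]`, form per-edge messages through the tensor product, and scatter-add them onto the receiver nodes in `rows`. The added `transpose_perm=None` argument presumably just aligns the reference signature with the kernel implementations (which take a sender permutation) and is unused here. A self-contained sketch of the pattern, using `torch.Tensor.index_add_` in place of the repo's `scatter_add_wrapper` and elementwise multiplication in place of `reference_tp` (both swaps are assumptions for illustration):

```python
import torch

def scatter_add(messages, rows, num_nodes):
    # Accumulate each edge's message into its receiver node's row.
    out = torch.zeros(num_nodes, messages.size(1), dtype=messages.dtype)
    return out.index_add_(0, rows, messages)

num_nodes, num_edges, dim = 4, 6, 8
L1_in = torch.randn(num_nodes, dim)               # node features
L2_in = torch.randn(num_edges, dim)               # per-edge features
rows = torch.randint(0, num_nodes, (num_edges,))  # receiver of each edge
cols = torch.randint(0, num_nodes, (num_edges,))  # sender of each edge

# Gather senders, form messages (stand-in for reference_tp), scatter-add.
messages = L1_in[cols] * L2_in
out = scatter_add(messages, rows, L1_in.size(0))
print(out.shape)  # torch.Size([4, 8])
```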

openequivariance/pyproject.toml

Lines changed: 1 addition & 2 deletions
@@ -48,8 +48,7 @@ bench = [
     "cuequivariance-ops-torch-cu12",
 ]

-jax = [
-    "jax[cuda12]",
+jax = [
     "nanobind",
     "scikit-build-core"
 ]
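
With this change the `jax` extra carries only build-time dependencies (`nanobind`, `scikit-build-core`) and no longer pulls in `jax[cuda12]`. The commit does not say why, but a plausible reading is that users are now expected to install a JAX build matching their own accelerator stack rather than having a CUDA 12 build pinned by the extra.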
