Skip to content

Commit 617d996

Browse files
committed
Backward convolution is failing; need to figure out why.
1 parent 7f4ac06 commit 617d996

4 files changed

Lines changed: 65 additions & 7 deletions

File tree

openequivariance/openequivariance/impl_jax/TensorProduct.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,5 +107,3 @@ def backward_cpu(
107107
L2_grad[:] = np.asarray(L2_grad_jax)
108108
weights_grad[:] = np.asarray(weights_grad_jax)
109109

110-
111-

openequivariance/openequivariance/impl_jax/TensorProductConv.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,52 @@ def __call__(
162162
sender_perm: Optional[jax.numpy.ndarray] = None,
163163
) -> jax.numpy.ndarray:
164164
return self.forward(X, Y, W, rows, cols, sender_perm)
165+
166+
def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph):
    """Run the convolution forward pass on host (numpy) arrays, in place.

    Casts the graph connectivity (rows, cols, sender permutation) to
    int32, moves all inputs to JAX arrays, invokes ``self.forward``, and
    copies the result back into the preallocated ``L3_out`` buffer.
    """
    to_jax = jax.numpy.asarray
    # Connectivity indices are cast to int32 before the device transfer.
    edge_rows = to_jax(graph.rows.astype(np.int32))
    edge_cols = to_jax(graph.cols.astype(np.int32))
    perm = to_jax(graph.transpose_perm.astype(np.int32))

    out = self.forward(
        to_jax(L1_in),
        to_jax(L2_in),
        to_jax(weights),
        edge_rows,
        edge_cols,
        perm,
    )
    # Write through the caller-owned output buffer rather than returning.
    L3_out[:] = np.asarray(out)
179+
180+
def backward_cpu(
    self,
    L1_in,
    L1_grad,
    L2_in,
    L2_grad,
    L3_grad,
    weights,
    weights_grad,
    graph,
):
    """Compute input/weight gradients via ``jax.vjp``, writing in place.

    Differentiates ``self.forward`` with respect to (L1_in, L2_in,
    weights) only; the graph connectivity arrays are closed over as
    constants. Results are copied into the preallocated ``L1_grad``,
    ``L2_grad``, and ``weights_grad`` buffers.
    """
    to_jax = jax.numpy.asarray
    # Convert connectivity once, outside the traced function; under JAX
    # tracing these behave as constants either way.
    edge_rows = to_jax(graph.rows.astype(np.int32))
    edge_cols = to_jax(graph.cols.astype(np.int32))
    perm = to_jax(graph.transpose_perm.astype(np.int32))

    def fwd(X, Y, W):
        # Only X, Y, W are differentiated; connectivity is captured.
        return self.forward(X, Y, W, edge_rows, edge_cols, perm)

    # vjp returns (primal_out, pullback); we only need the pullback.
    _, vjp_fn = jax.vjp(
        fwd,
        to_jax(L1_in),
        to_jax(L2_in),
        to_jax(weights),
    )
    dL1, dL2, dW = vjp_fn(to_jax(L3_grad))

    L1_grad[:] = np.asarray(dL1)
    L2_grad[:] = np.asarray(dL2)
    weights_grad[:] = np.asarray(dW)

tests/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
import os
33

44
os.environ["JAX_ENABLE_X64"] = "True"
5+
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
56
def pytest_addoption(parser):
67
parser.addoption(
7-
"--jax", action="store", default=False, help="Test the JAX frontend instead of PyTorch"
8+
"--jax", action="store_true", default=False, help="Test the JAX frontend instead of PyTorch"
89
)

tests/conv_test.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pytest_check import check
55

66
import numpy as np
7+
import openequivariance
78
import openequivariance as oeq
89
from openequivariance.benchmark.ConvBenchmarkSuite import load_graph
910
from itertools import product
@@ -51,23 +52,32 @@ def graph(self, request):
5152
def extra_conv_constructor_args(self):
5253
return {}
5354

55+
@pytest.fixture(scope="class")
56+
def test_jax(self, request):
57+
return request.config.getoption("--jax")
58+
5459
@pytest.fixture(params=["atomic", "deterministic", "kahan"], scope="class")
55-
def conv_object(self, request, problem, extra_conv_constructor_args):
60+
def conv_object(self, request, problem, extra_conv_constructor_args, test_jax):
61+
cls = oeq.TensorProductConv
62+
if test_jax:
63+
from openequivariance.impl_jax import TensorProductConv as jax_conv
64+
cls = jax_conv
65+
5666
if request.param == "atomic":
57-
return oeq.TensorProductConv(
67+
return cls(
5868
problem, deterministic=False, **extra_conv_constructor_args
5969
)
6070
elif request.param == "deterministic":
6171
if not problem.shared_weights:
62-
return oeq.TensorProductConv(
72+
return cls(
6373
problem, deterministic=True, **extra_conv_constructor_args
6474
)
6575
else:
6676
pytest.skip("Shared weights not supported with deterministic")
6777
elif request.param == "kahan":
6878
if problem.irrep_dtype == np.float32:
6979
if not problem.shared_weights:
70-
return oeq.TensorProductConv(
80+
return cls(
7181
problem,
7282
deterministic=True,
7383
kahan=True,

0 commit comments

Comments
 (0)