
Commit c3f83ea

Batch test is working.
Parent: d815424

2 files changed: 18 additions & 10 deletions

openequivariance/openequivariance/impl_jax/TensorProduct.py (4 additions, 1 deletion)
@@ -90,6 +90,7 @@ def reorder_weights_to_e3nn(self, weights, has_batch_dim=True):
         return reorder_jax(self.forward_schedule, weights, "backward", not self.config.shared_weights)
 
     def forward_cpu(self, L1_in, L2_in, L3_out, weights) -> None:
+        weights = self.reorder_weights_from_e3nn(weights, has_batch_dim=not self.config.shared_weights)
         result = self.forward(
             jax.numpy.asarray(L1_in),
             jax.numpy.asarray(L2_in),
@@ -100,6 +101,7 @@ def forward_cpu(self, L1_in, L2_in, L3_out, weights) -> None:
     def backward_cpu(
         self, L1_in, L1_grad, L2_in, L2_grad, L3_grad, weights, weights_grad
     ) -> None:
+        weights = self.reorder_weights_from_e3nn(weights, has_batch_dim=not self.config.shared_weights)
         backward_fn = jax.vjp(
             lambda X, Y, W: self.forward(X, Y, W),
             jax.numpy.asarray(L1_in),
@@ -111,4 +113,5 @@ def backward_cpu(
         )
         L1_grad[:] = np.asarray(L1_grad_jax)
         L2_grad[:] = np.asarray(L2_grad_jax)
-        weights_grad[:] = np.asarray(weights_grad_jax)
+        weights_grad[:] = np.asarray(weights_grad_jax)
+        weights_grad[:] = self.reorder_weights_to_e3nn(weights_grad, has_batch_dim=not self.config.shared_weights)
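The TensorProduct.py changes establish a round-trip convention: weights arrive in e3nn's canonical order, reorder_weights_from_e3nn permutes them into the kernel's internal layout before the forward pass or VJP, and, because that reorder happens outside jax.vjp, the weight gradient comes back in the internal layout and must be permuted back to e3nn order before it is handed to the caller. A minimal sketch of that round-trip, assuming a hypothetical 3-element permutation and a stand-in elementwise kernel in place of the library's reorder_jax and TensorProduct.forward:

    import jax
    import jax.numpy as jnp

    perm = jnp.array([2, 0, 1])   # hypothetical e3nn -> internal permutation
    inv_perm = jnp.argsort(perm)  # internal -> e3nn

    def from_e3nn(w):
        # stand-in for reorder_weights_from_e3nn; works with or without a batch dim
        return w[..., perm]

    def to_e3nn(w):
        # stand-in for reorder_weights_to_e3nn
        return w[..., inv_perm]

    def kernel(x, w):
        # stand-in for TensorProduct.forward
        return x * w

    x = jnp.array([1.0, 2.0, 3.0])
    w_e3nn = jnp.array([10.0, 20.0, 30.0])

    w_internal = from_e3nn(w_e3nn)                   # reorder before the kernel
    out, vjp_fn = jax.vjp(lambda w: kernel(x, w), w_internal)
    (w_grad_internal,) = vjp_fn(jnp.ones_like(out))  # gradient in internal order
    w_grad_e3nn = to_e3nn(w_grad_internal)           # final step, as in backward_cpu
    print(w_grad_e3nn)                               # [2. 3. 1.]

Because the permutation sits outside the differentiated function, the VJP yields gradients in the internal layout, which is why backward_cpu applies reorder_weights_to_e3nn as its last statement.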

tests/batch_test.py (14 additions, 9 deletions)
@@ -254,7 +254,9 @@ def problem(self, request, dtype):
 
 class TestTorchbindDisable(TestProductionModels):
     @pytest.fixture(scope="class")
-    def extra_tp_constructor_args(self):
+    def extra_tp_constructor_args(self, test_jax):
+        if test_jax:
+            pytest.skip("N/A for JAX")
         return {"use_opaque": True}
 
 
@@ -268,11 +270,14 @@ def problem(self, request, dtype):
         return problem
 
     @pytest.fixture(scope="class")
-    def tp_and_problem(self, problem, extra_tp_constructor_args):
-        tp = TensorProduct(problem, **extra_tp_constructor_args)
-        switch_map = {
-            np.float32: torch.float64,
-            np.float64: torch.float32,
-        }
-        tp.to(switch_map[problem.irrep_dtype])
-        return tp, tp.config
+    def tp_and_problem(self, problem, extra_tp_constructor_args, test_jax):
+        if test_jax:
+            pytest.skip("N/A for JAX")
+        else:
+            tp = oeq.TensorProduct(problem, **extra_tp_constructor_args)
+            switch_map = {
+                np.float32: torch.float64,
+                np.float64: torch.float32,
+            }
+            tp.to(switch_map[problem.irrep_dtype])
+            return tp, tp.config
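Both fixture changes lean on a pytest behavior worth noting: pytest.skip() raised inside a fixture skips every test that requests that fixture, so torch-only setup can opt out cleanly when the suite runs against JAX. A self-contained sketch of the pattern, with a hypothetical parametrized test_jax fixture standing in for whatever the real suite defines:

    import pytest

    @pytest.fixture(scope="class", params=[False, True], ids=["torch", "jax"])
    def test_jax(request):
        # hypothetical stand-in: the real suite defines how test_jax is produced
        return request.param

    @pytest.fixture(scope="class")
    def extra_tp_constructor_args(test_jax):
        if test_jax:
            pytest.skip("N/A for JAX")  # skips every test depending on this fixture
        return {"use_opaque": True}

    class TestTorchbindDisable:
        def test_args(self, extra_tp_constructor_args):
            assert extra_tp_constructor_args == {"use_opaque": True}

Run with pytest -v, the torch-parametrized case executes and the jax case is reported as skipped rather than failed.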
