Skip to content

Commit fa42654

Browse files
committed
1/3 tests is passing.
1 parent f16c622 commit fa42654

4 files changed

Lines changed: 39 additions & 5 deletions

File tree

openequivariance/openequivariance/impl_jax/TensorProduct.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import jax
2+
import numpy as np
23
from functools import partial
34
from openequivariance.impl_jax import extlib
45
from openequivariance.core.e3nn_lite import TPProblem
@@ -74,10 +75,24 @@ def __init__(self, config: TPProblem):
7475
self.weight_numel = config.weight_numel
7576
self.L3_dim = self.config.irreps_out.dim
7677

77-
def forward(self, X: jax.ndarray, Y: jax.ndarray, W: jax.ndarray) -> jax.ndarray:
78+
def forward(self, X: jax.numpy.ndarray, Y: jax.numpy.ndarray, W: jax.numpy.ndarray) -> jax.numpy.ndarray:
7879
return forward(X, Y, W, self.L3_dim, self.config.irrep_dtype, self.attrs)
7980

8081
def __call__(
8182
self, X: jax.numpy.ndarray, Y: jax.numpy.ndarray, W: jax.numpy.ndarray
8283
) -> jax.numpy.ndarray:
8384
return self.forward(X, Y, W)
85+
86+
def forward_cpu(
87+
self,
88+
L1_in: np.ndarray,
89+
L2_in: np.ndarray,
90+
L3_out: np.ndarray,
91+
weights: np.ndarray,
92+
) -> None:
93+
result = self.forward(
94+
jax.numpy.asarray(L1_in),
95+
jax.numpy.asarray(L2_in),
96+
jax.numpy.asarray(weights),
97+
)
98+
L3_out[:] = np.asarray(result)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from openequivariance.impl_jax.TensorProduct import TensorProduct as TensorProduct
2+
from openequivariance.impl_jax.TensorProductConv import TensorProductConv as TensorProductConv
3+
4+
__all__ = ["TensorProduct", "TensorProductConv"]

tests/batch_test.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from pytest_check import check
33

44
import numpy as np
5+
import openequivariance
56
import openequivariance as oeq
6-
from openequivariance.impl_torch.TensorProduct import TensorProduct
77
from openequivariance.benchmark.correctness_utils import (
88
correctness_forward,
99
correctness_backward,
@@ -19,7 +19,6 @@
1919
from itertools import product
2020
import torch
2121

22-
2322
class TPCorrectness:
2423
def thresh(self, direction):
2524
return {"fwd": 1e-5, "bwd": 3e-4, "double_bwd": 3e-4}[direction]
@@ -41,8 +40,16 @@ def extra_tp_constructor_args(self):
4140
return {}
4241

4342
@pytest.fixture(scope="class")
44-
def tp_and_problem(self, problem, extra_tp_constructor_args):
45-
tp = TensorProduct(problem, **extra_tp_constructor_args)
43+
def test_jax(self, request):
44+
return request.config.getoption("--jax")
45+
46+
@pytest.fixture(scope="class")
47+
def tp_and_problem(self, problem, extra_tp_constructor_args, test_jax):
48+
cls = oeq.TensorProduct
49+
if test_jax:
50+
import openequivariance.impl_jax.TensorProduct as jax_tp
51+
cls = jax_tp
52+
tp = cls(problem, **extra_tp_constructor_args)
4653
return tp, problem
4754

4855
def test_tp_fwd(self, tp_and_problem):

tests/conftest.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import pytest
2+
import os
3+
4+
os.environ["JAX_ENABLE_X64"] = "True"
5+
def pytest_addoption(parser):
6+
parser.addoption(
7+
"--jax", action="store", default=False, help="Test the JAX frontend instead of PyTorch"
8+
)

0 commit comments

Comments
 (0)