Skip to content

Commit b7af425

Browse files
committed
Renamed.
1 parent 9d6e30e commit b7af425

10 files changed

Lines changed: 12 additions & 13 deletions

File tree

docs/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
"torch",
3939
"jax",
4040
"openequivariance.impl_torch.extlib",
41-
"openequivariance.impl_jax.extlib",
41+
"openequivariance.jax.extlib",
4242
"openequivariance_extjax",
4343
"jinja2",
4444
"numpy",

openequivariance/openequivariance/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def torch_ext_so_path():
7272
jax = None
7373
try:
7474
import openequivariance_extjax
75-
import openequivariance.impl_jax as jax
75+
import openequivariance.jax as jax
7676
except ImportError:
7777
pass
7878

openequivariance/openequivariance/impl_jax/__init__.py

Lines changed: 0 additions & 4 deletions
This file was deleted.

openequivariance/openequivariance/impl_jax/TensorProduct.py renamed to openequivariance/openequivariance/jax/TensorProduct.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import jax
22
import numpy as np
33
from functools import partial
4-
from openequivariance.impl_jax import extlib
4+
from openequivariance.jax import extlib
55
from openequivariance.core.e3nn_lite import TPProblem
66
from openequivariance.core.LoopUnrollTP import LoopUnrollTP
77
from openequivariance.core.utils import hash_attributes
8-
from openequivariance.impl_jax.utils import reorder_jax
8+
from openequivariance.jax.utils import reorder_jax
99

1010
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
1111
def forward(X, Y, W, L3_dim, irrep_dtype, attrs):

openequivariance/openequivariance/impl_jax/TensorProductConv.py renamed to openequivariance/openequivariance/jax/TensorProductConv.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import numpy as np
22
from functools import partial
33
from typing import Optional
4-
from openequivariance.impl_jax import extlib
4+
from openequivariance.jax import extlib
55

66
from openequivariance.core.e3nn_lite import TPProblem
77
from openequivariance.core.LoopUnrollConv import LoopUnrollConv
88
from openequivariance.core.utils import hash_attributes
9-
from openequivariance.impl_jax.utils import reorder_jax
9+
from openequivariance.jax.utils import reorder_jax
1010

1111
import jax
1212
import jax.numpy as jnp
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from openequivariance.jax.TensorProduct import TensorProduct as TensorProduct
2+
from openequivariance.jax.TensorProductConv import TensorProductConv as TensorProductConv
3+
4+
__all__ = ["TensorProduct", "TensorProductConv"]

openequivariance/openequivariance/impl_jax/extlib/__init__.py renamed to openequivariance/openequivariance/jax/extlib/__init__.py

File renamed without changes.
File renamed without changes.

tests/batch_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def test_jax(self, request):
4747
def tp_and_problem(self, problem, extra_tp_constructor_args, test_jax):
4848
cls = oeq.TensorProduct
4949
if test_jax:
50-
import openequivariance.impl_jax.TensorProduct as jax_tp
50+
import openequivariance.jax.TensorProduct as jax_tp
5151
cls = jax_tp
5252
tp = cls(problem, **extra_tp_constructor_args)
5353
return tp, problem

tests/conv_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from pytest_check import check
55

66
import numpy as np
7-
import openequivariance
87
import openequivariance as oeq
98
from openequivariance.benchmark.ConvBenchmarkSuite import load_graph
109
from itertools import product
@@ -60,7 +59,7 @@ def test_jax(self, request):
6059
def conv_object(self, request, problem, extra_conv_constructor_args, test_jax):
6160
cls = oeq.TensorProductConv
6261
if test_jax:
63-
from openequivariance.impl_jax import TensorProductConv as jax_conv
62+
from openequivariance.jax import TensorProductConv as jax_conv
6463
cls = jax_conv
6564

6665
if request.param == "atomic":

0 commit comments

Comments
 (0)