Skip to content

Commit 1bcea33

Browse files
committed
Renaming + added JAX example.
1 parent b7af425 commit 1bcea33

30 files changed

Lines changed: 127 additions & 67 deletions

docs/api.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ PyTorch API
3030
:undoc-members:
3131
:exclude-members: name
3232

33-
.. autofunction:: openequivariance.torch_to_oeq_dtype
33+
.. autofunction:: openequivariance._torch_to_oeq_dtype
3434

35-
.. autofunction:: openequivariance.torch_ext_so_path
35+
.. autofunction:: openequivariance._torch_ext_so_path
3636

3737
JAX API
3838
------------------------

docs/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
autodoc_mock_imports = [
3838
"torch",
3939
"jax",
40-
"openequivariance.impl_torch.extlib",
40+
"openequivariance._torch.extlib",
4141
"openequivariance.jax.extlib",
4242
"openequivariance_extjax",
4343
"jinja2",

docs/supported_ops.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ toplevel. You can use our implementation by running
117117

118118
.. code-block::
119119
120-
from openequivariance.impl_torch.symmetric_contraction import SymmetricContraction as OEQSymmetricContraction
120+
from openequivariance._torch.symmetric_contraction import SymmetricContraction as OEQSymmetricContraction
121121
122122
Some Github users report weak performance for the
123123
symmetric contraction backward pass; your mileage may vary.

openequivariance/openequivariance/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,11 @@ def extension_source_path():
3939

4040
if "OEQ_NOTORCH" not in os.environ or os.environ["OEQ_NOTORCH"] != "1":
4141
import torch
42-
from openequivariance.impl_torch.TensorProduct import TensorProduct
43-
from openequivariance.impl_torch.TensorProductConv import TensorProductConv
4442

45-
from openequivariance.impl_torch.extlib import torch_ext_so_path as torch_ext_so_path_internal
43+
from openequivariance._torch.TensorProduct import TensorProduct
44+
from openequivariance._torch.TensorProductConv import TensorProductConv
45+
46+
from openequivariance._torch.extlib import torch_ext_so_path as torch_ext_so_path_internal
4647
from openequivariance.core.utils import torch_to_oeq_dtype
4748

4849
torch.serialization.add_safe_globals(

openequivariance/openequivariance/impl_torch/CUEConv.py renamed to openequivariance/openequivariance/_torch/CUEConv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import itertools
33
from typing import Iterator
44

5-
from openequivariance.impl_torch.CUETensorProduct import CUETensorProduct
5+
from openequivariance._torch.CUETensorProduct import CUETensorProduct
66
from openequivariance.core.ConvolutionBase import (
77
ConvolutionBase,
88
scatter_add_wrapper,

openequivariance/openequivariance/impl_torch/CUETensorProduct.py renamed to openequivariance/openequivariance/_torch/CUETensorProduct.py

File renamed without changes.

openequivariance/openequivariance/impl_torch/E3NNConv.py renamed to openequivariance/openequivariance/_torch/E3NNConv.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
ConvolutionBase,
55
scatter_add_wrapper,
66
)
7-
from openequivariance.impl_torch.E3NNTensorProduct import E3NNTensorProduct
8-
from openequivariance.impl_torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixinConv
7+
from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct
8+
from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixinConv
99

1010
class E3NNConv(ConvolutionBase, NumpyDoubleBackwardMixinConv):
1111
def __init__(self, config, *, idx_dtype=np.int64, torch_op=True):

openequivariance/openequivariance/impl_torch/E3NNTensorProduct.py renamed to openequivariance/openequivariance/_torch/E3NNTensorProduct.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from openequivariance.core.TensorProductBase import TensorProductBase
1313
from openequivariance.core.e3nn_lite import TPProblem
1414
from openequivariance.benchmark.logging_utils import getLogger
15-
from openequivariance.impl_torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin
15+
from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin
1616

1717
TORCH_COMPILE_AUTOTUNING_DIR = pathlib.Path("triton_autotuning")
1818

openequivariance/openequivariance/impl_torch/FlashTPConv.py renamed to openequivariance/openequivariance/_torch/FlashTPConv.py

File renamed without changes.

openequivariance/openequivariance/impl_torch/NPDoubleBackwardMixin.py renamed to openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py

File renamed without changes.

0 commit comments

Comments
 (0)