Skip to content

Commit 44af17b

Browse files
committed
Modified documentation.
1 parent 50e0fcc commit 44af17b

5 files changed

Lines changed: 60 additions & 29 deletions

File tree

docs/api.rst

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,26 @@ trying our code. OpenEquivariance cannot accelerate all tensor products; see
3131
:members:
3232
:undoc-members:
3333

34-
.. autofunction:: openequivariance.impl_torch_to_oeq_dtype
34+
.. autofunction:: openequivariance.torch_to_oeq_dtype
35+
36+
.. autofunction:: openequivariance.torch_ext_so_path
37+
38+
OpenEquivariance JAX API
39+
------------------------
40+
The JAX API consists of ``TensorProduct`` and ``TensorProductConv``
41+
classes that behave identically to their PyTorch counterparts. These classes
42+
do not conform exactly to the e3nn-jax API, but perform the same computation.
43+
44+
.. autoclass:: openequivariance.jax.TensorProduct
45+
:members: forward, reorder_weights_from_e3nn, reorder_weights_to_e3nn
46+
:undoc-members:
47+
:exclude-members:
48+
49+
.. autoclass:: openequivariance.jax.TensorProductConv
50+
:members: forward, reorder_weights_from_e3nn, reorder_weights_to_e3nn
51+
:undoc-members:
52+
:exclude-members:
3553

36-
.. autofunction:: openequivariance.impl_torch_ext_so_path
3754

3855
API Identical to e3nn
3956
---------------------

docs/conf.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,14 @@
3232
"sphinx.ext.autodoc",
3333
]
3434

35-
sys.path.insert(0, str(Path("..").resolve()))
35+
sys.path.insert(0, str(Path("../openequivariance").resolve()))
3636

3737
autodoc_mock_imports = [
3838
"torch",
39+
"jax",
3940
"openequivariance.impl_torch.extlib",
41+
"openequivariance.impl_jax.extlib",
42+
"openequivariance_extjax",
4043
"jinja2",
4144
"numpy",
4245
]

openequivariance/openequivariance/__init__.py

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# ruff: noqa: F401
22
import sys
3-
import torch
3+
import os
44
import numpy as np
55

66
from pathlib import Path
@@ -13,12 +13,6 @@
1313
_MulIr,
1414
Instruction,
1515
)
16-
from openequivariance.impl_torch.TensorProduct import TensorProduct
17-
from openequivariance.impl_torch.TensorProductConv import (
18-
TensorProductConv,
19-
)
20-
from openequivariance.impl_torch.extlib import torch_ext_so_path
21-
from openequivariance.core.utils import torch_to_oeq_dtype
2216

2317
__version__ = None
2418
try:
@@ -44,20 +38,36 @@ def extension_source_path():
4438
"""
4539
return str(Path(__file__).parent / "extension")
4640

41+
TensorProduct, TensorProductConv, torch_ext_so_path, torch_to_oeq_dtype = None, None, None, None
4742

48-
torch.serialization.add_safe_globals(
49-
[
50-
TensorProduct,
51-
TensorProductConv,
52-
TPProblem,
53-
Irrep,
54-
Irreps,
55-
_MulIr,
56-
Instruction,
57-
np.float32,
58-
np.float64,
59-
]
60-
)
43+
if "OEQ_NOTORCH" not in os.environ or os.environ["OEQ_NOTORCH"] != "1":
44+
import torch
45+
from openequivariance.impl_torch.TensorProduct import TensorProduct
46+
from openequivariance.impl_torch.TensorProductConv import TensorProductConv
47+
48+
from openequivariance.impl_torch.extlib import torch_ext_so_path
49+
from openequivariance.core.utils import torch_to_oeq_dtype
50+
51+
torch.serialization.add_safe_globals(
52+
[
53+
TensorProduct,
54+
TensorProductConv,
55+
TPProblem,
56+
Irrep,
57+
Irreps,
58+
_MulIr,
59+
Instruction,
60+
np.float32,
61+
np.float64,
62+
]
63+
)
64+
65+
jax = None
66+
try:
67+
import openequivariance_extjax
68+
import openequivariance.impl_jax as jax
69+
except ImportError:
70+
pass
6171

6272
__all__ = [
6373
"TPProblem",
@@ -67,4 +77,5 @@ def extension_source_path():
6777
"torch_to_oeq_dtype",
6878
"_check_package_editable",
6979
"torch_ext_so_path",
80+
"jax"
7081
]

openequivariance/openequivariance/core/TensorProductBase.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def reorder_weights_from_e3nn(self, weights, has_batch_dim: bool = True):
4444
Reorders weights from ``e3nn`` canonical order to the order used by ``oeq``.
4545
4646
:param weights: Weights in ``e3nn`` canonical order, either an
47-
np.ndarray or a torch.Tensor. Tensor of dimensions ``[B, problem.weight_numel]``
47+
np.ndarray, torch.Tensor or JAX array. Tensor of dimensions ``[B, problem.weight_numel]``
4848
when ``has_batch_dim=True``, otherwise of dimensions ``[problem.weight_numel]``.
4949
5050
:param has_batch_dim: If ``True``, treats the first dimension of weights as a batch dimension. Default: ``True``.
@@ -57,8 +57,8 @@ def reorder_weights_to_e3nn(self, weights, has_batch_dim: bool = True):
5757
r"""
5858
Reorders weights from ``oeq`` canonical order to the order used by ``e3nn``.
5959
60-
:param weights: Weights in ``oeq`` canonical order, either an
61-
np.ndarray or a torch.Tensor. Tensor of dimensions ``[B, problem.weight_numel]``
60+
:param weights: Weights in ``oeq`` canonical order, either a
61+
np.ndarray, torch.Tensor or JAX array. Tensor of dimensions ``[B, problem.weight_numel]``
6262
when ``has_batch_dim=True``, otherwise of dimensions ``[problem.weight_numel]``.
6363
6464
:param has_batch_dim: If ``True``, treats the first dimension of wieghts as a batch dimension. Default: ``True``.

openequivariance/openequivariance/impl_jax/TensorProduct.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ class TensorProduct(LoopUnrollTP):
6565
:param problem: Specification of the tensor product.
6666
"""
6767

68-
def __init__(self, config: TPProblem):
68+
def __init__(self, problem: TPProblem):
6969
dp = extlib.DeviceProp(0)
70-
super().__init__(config, dp, extlib.postprocess_kernel, torch_op=False)
70+
super().__init__(problem, dp, extlib.postprocess_kernel, torch_op=False)
7171

7272
self.attrs = {
7373
"kernel": self.jit_kernel,
@@ -78,7 +78,7 @@ def __init__(self, config: TPProblem):
7878
}
7979
hash_attributes(self.attrs)
8080

81-
self.weight_numel = config.weight_numel
81+
self.weight_numel = problem.weight_numel
8282
self.L3_dim = self.config.irreps_out.dim
8383

8484
def forward(self, X: jax.numpy.ndarray, Y: jax.numpy.ndarray, W: jax.numpy.ndarray) -> jax.numpy.ndarray:

0 commit comments

Comments
 (0)