Skip to content

Commit 8caa93e

Browse files
committed
Updated documentation.
1 parent 44af17b commit 8caa93e

4 files changed

Lines changed: 61 additions & 22 deletions

File tree

docs/api.rst

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ OpenEquivariance API
88
OpenEquivariance exposes two key classes: :py:class:`openequivariance.TensorProduct`, which replaces
99
``o3.TensorProduct`` from e3nn, and :py:class:`openequivariance.TensorProductConv`, which fuses
1010
the CG tensor product with a subsequent graph convolution. Initializing either class triggers
11-
JIT compilation of a custom kernel, which can take a few seconds.
11+
JIT compilation of a custom kernel, which can take a few seconds.
1212

1313
Both classes require a configuration object specified
1414
by :py:class:`openequivariance.TPProblem`, which has a constructor
@@ -17,6 +17,9 @@ We recommend reading the `e3nn documentation <https://docs.e3nn.org/en/latest/>`
1717
trying our code. OpenEquivariance cannot accelerate all tensor products; see
1818
:doc:`this page </supported_ops>` for a list of supported configurations.
1919

20+
PyTorch API
21+
------------------------
22+
2023
.. autoclass:: openequivariance.TensorProduct
2124
:members: forward, reorder_weights_from_e3nn, reorder_weights_to_e3nn, to
2225
:undoc-members:
@@ -27,19 +30,21 @@ trying our code. OpenEquivariance cannot accelerate all tensor products; see
2730
:undoc-members:
2831
:exclude-members: name
2932

30-
.. autoclass:: openequivariance.TPProblem
31-
:members:
32-
:undoc-members:
33-
3433
.. autofunction:: openequivariance.torch_to_oeq_dtype
3534

3635
.. autofunction:: openequivariance.torch_ext_so_path
3736

38-
OpenEquivariance JAX API
37+
JAX API
3938
------------------------
4039
The JAX API consists of ``TensorProduct`` and ``TensorProductConv``
4140
classes that behave identically to their PyTorch counterparts. These classes
42-
do not conform exactly to the e3nn-jax API, but perform the same computation.
41+
do not conform exactly to the e3nn-jax API, but perform the same computation.
42+
43+
If you plan to use ``oeq.jax`` without PyTorch installed,
44+
you need to set ``OEQ_NOTORCH=1`` in your local environment (within Python,
45+
``os.environ["OEQ_NOTORCH"] = 1``). For the moment, we require this to avoid
46+
breaking the PyTorch version of OpenEquivariance.
47+
4348

4449
.. autoclass:: openequivariance.jax.TensorProduct
4550
:members: forward, reorder_weights_from_e3nn, reorder_weights_to_e3nn
@@ -51,6 +56,12 @@ do not conform exactly to the e3nn-jax API, but perform the same computation.
5156
:undoc-members:
5257
:exclude-members:
5358

59+
Common API
60+
---------------------
61+
62+
.. autoclass:: openequivariance.TPProblem
63+
:members:
64+
:undoc-members:
5465

5566
API Identical to e3nn
5667
---------------------

docs/installation.rst

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
Installation
1+
Installation (Torch and JAX)
22
==============================
33

44
.. toctree::
@@ -8,11 +8,15 @@ Installation
88
You need the following to install OpenEquivariance:
99

1010
- A Linux system equipped with an NVIDIA / AMD graphics card.
11-
- PyTorch >= 2.4 (>= 2.8 for AOTI and export).
11+
- Either PyTorch >= 2.4 (>= 2.8 for AOTI and export), or JAX with CUDA 12 support
12+
or higher.
1213
- GCC 9+ and the CUDA / HIP toolkit. The command
1314
``c++ --version`` should return >= 9.0; see below for details on
1415
setting an alternate compiler.
1516

17+
PyTorch
18+
------------------------------------------
19+
1620
Installation is one easy command, followed by import verification:
1721

1822
.. code-block:: bash
@@ -28,11 +32,8 @@ To get the nightly build, run
2832

2933
.. code-block:: bash
3034
31-
pip install git+https://github.com/PASSIONLab/OpenEquivariance
32-
35+
pip install git+https://github.com/PASSIONLab/OpenEquivariance#subdirectory=openequivariance
3336
34-
Compiling the Integrated PyTorch Extension
35-
------------------------------------------
3637
To support ``torch.compile``, ``torch.export``, and
3738
JITScript, OpenEquivariance needs to compile a C++ extension
3839
tightly integrated with PyTorch. If you see a warning that
@@ -48,13 +49,37 @@ environment variable and retry the import:
4849

4950
.. code-block:: bash
5051
51-
export CCC=/path/to/your/gcc
52+
export CC=/path/to/your/gcc
5253
export CXX=/path/to/your/g++
5354
python -c "import openequivariance"
5455
5556
These configuration steps are required only ONCE after
5657
installation (or upgrade) with pip.
5758

59+
JAX
60+
------------------------------------------
61+
JAX support is currently limited to NVIDIA GPUs. You need to execute
62+
the following two commands strictly in order:
63+
64+
.. code-block:: bash
65+
66+
pip install openequivariance[jax]
67+
pip install openequivariance_extjax --no-build-isolation
68+
69+
From there, set ``OEQ_NOTORCH=1`` to avoid a PyTorch import and test the package:
70+
71+
.. code-block:: bash
72+
73+
OEQ_NOTORCH=1
74+
python -c "import openequivariance.jax"
75+
76+
You can get the nightly build as follows:
77+
78+
.. code-block:: bash
79+
80+
pip install git+https://github.com/PASSIONLab/OpenEquivariance#subdirectory=openequivariance[jax]
81+
pip install git+https://github.com/PASSIONLab/OpenEquivariance#subdirectory=openequivariance_extjax
82+
5883
Configurations on Major Platforms
5984
---------------------------------
6085
OpenEquivariance has been tested on both supercomputers and lab clusters.

openequivariance/openequivariance/__init__.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,21 +31,18 @@ def _check_package_editable():
3131

3232
_editable_install_output_path = Path(__file__).parent.parent.parent / "outputs"
3333

34-
3534
def extension_source_path():
3635
"""
3736
:returns: Path to the source code of the C++ extension.
3837
"""
3938
return str(Path(__file__).parent / "extension")
4039

41-
TensorProduct, TensorProductConv, torch_ext_so_path, torch_to_oeq_dtype = None, None, None, None
42-
4340
if "OEQ_NOTORCH" not in os.environ or os.environ["OEQ_NOTORCH"] != "1":
4441
import torch
4542
from openequivariance.impl_torch.TensorProduct import TensorProduct
4643
from openequivariance.impl_torch.TensorProductConv import TensorProductConv
4744

48-
from openequivariance.impl_torch.extlib import torch_ext_so_path
45+
from openequivariance.impl_torch.extlib import torch_ext_so_path as torch_ext_so_path_internal
4946
from openequivariance.core.utils import torch_to_oeq_dtype
5047

5148
torch.serialization.add_safe_globals(
@@ -62,6 +59,16 @@ def extension_source_path():
6259
]
6360
)
6461

62+
def torch_ext_so_path():
63+
"""
64+
:returns: Path to a ``.so`` file that must be linked to use OpenEquivariance
65+
from the PyTorch C++ Interface.
66+
"""
67+
try:
68+
return torch_ext_so_path_internal()
69+
except NameError:
70+
return None
71+
6572
jax = None
6673
try:
6774
import openequivariance_extjax

openequivariance/openequivariance/impl_torch/extlib/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,10 +142,6 @@ def _raise_import_error_helper(import_target: str):
142142

143143

144144
def torch_ext_so_path():
145-
"""
146-
:returns: Path to a ``.so`` file that must be linked to use OpenEquivariance
147-
from the PyTorch C++ Interface.
148-
"""
149145
return torch_module.__file__
150146

151147

0 commit comments

Comments
 (0)