Skip to content

Commit 50e0fcc

Browse files
committed
Updated documentation.
1 parent ab83aef commit 50e0fcc

4 files changed

Lines changed: 41 additions & 4 deletions

File tree

README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@ For detailed instructions on tests, benchmarks, MACE / Nequip, and our API,
3030
check out the [documentation](https://passionlab.github.io/OpenEquivariance).
3131

3232
⭐️ **JAX Support**: Our latest update brings
33-
support for JAX. You need to execute the following
34-
commands in order:
33+
support for JAX. To install, execute the following commands in order:
3534

3635
```
3736
pip install openequivariance[jax]

openequivariance/openequivariance/impl_jax/TensorProduct.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ def backward_autograd(L3_dim, irrep_dtype, attrs, inputs, dZ):
5959

6060

6161
class TensorProduct(LoopUnrollTP):
62+
r"""
63+
Identical to ``oeq.torch.TensorProduct`` with functionality in JAX.
64+
65+
:param problem: Specification of the tensor product.
66+
"""
67+
6268
def __init__(self, config: TPProblem):
6369
dp = extlib.DeviceProp(0)
6470
super().__init__(config, dp, extlib.postprocess_kernel, torch_op=False)

openequivariance/openequivariance/impl_jax/TensorProductConv.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,18 @@ def backward_autograd(
9292

9393

9494
class TensorProductConv(LoopUnrollConv):
95+
r"""
96+
Identical to ``oeq.torch.TensorProductConv`` with functionality in JAX, with one
97+
key difference: integer arrays passed to this function must have dtype
98+
``np.int32`` (as opposed to ``np.int64`` in the PyTorch version).
99+
100+
:param problem: Specification of the tensor product.
101+
:param deterministic: if ``False``, uses atomics for the convolution. If ``True``, uses a deterministic
102+
fixup-based algorithm. `Default`: ``False``.
103+
:param kahan: If ``True``, uses Kahan summation to improve accuracy during aggregation. To use this option,
104+
the input tensors must be in float32 precision AND you must set ``deterministic=True``. *Default*: ``False``.
105+
"""
106+
95107
def __init__(
96108
self, config: TPProblem, deterministic: bool = False, kahan: bool = False
97109
):
@@ -132,7 +144,27 @@ def forward(
132144
rows: jax.numpy.ndarray,
133145
cols: jax.numpy.ndarray,
134146
sender_perm: Optional[jax.numpy.ndarray] = None,
135-
) -> jax.numpy.ndarray:
147+
) -> jax.numpy.ndarray:
148+
r"""
149+
Computes the fused CG tensor product + convolution.
150+
151+
:param X: Tensor of shape ``[|V|, problem.irreps_in1.dim()]``, datatype ``problem.irrep_dtype``.
152+
:param Y: Tensor of shape ``[|E|, problem.irreps_in1.dim()]``, datatype ``problem.irrep_dtype``.
153+
:param W: Tensor of datatype ``problem.weight_dtype`` and shape
154+
155+
* ``[|E|, problem.weight_numel]`` if ``problem.shared_weights=False``
156+
* ``[problem.weight_numel]`` if ``problem.shared_weights=True``
157+
158+
:param rows: Tensor of shape ``[|E|]`` with row indices for each nonzero in the adjacency matrix,
159+
datatype ``np.int32``. Must be row-major sorted along with ``cols`` when ``deterministic=True``.
160+
:param cols: Tensor of shape ``[|E|]`` with column indices for each nonzero in the adjacency matrix,
161+
datatype ``np.int32``.
162+
:param sender_perm: Tensor of shape ``[|E|]`` and ``np.int32`` datatype containing a
163+
permutation that transposes the adjacency matrix nonzeros from row-major to column-major order.
164+
Must be provided when ``deterministic=True``.
165+
166+
:return: Tensor of shape ``[|V|, problem.irreps_out.dim()]``, datatype ``problem.irrep_dtype``.
167+
"""
136168
if not self.deterministic:
137169
sender_perm = self.dummy_transpose_perm
138170
else:

openequivariance_extjax/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
# OpenEquivariance JAX Extension
22

3-
The JAX extension module for OpenEquivariance.
3+
The JAX extension module for OpenEquivariance.

0 commit comments

Comments
 (0)