@@ -92,6 +92,18 @@ def backward_autograd(
9292
9393
9494class TensorProductConv (LoopUnrollConv ):
95+ r"""
96+ Identical to ``oeq.torch.TensorProductConv`` with functionality in JAX, with one
97+ key difference: integer arrays passed to this function must have dtype
98+ ``np.int32`` (as opposed to ``np.int64`` in the PyTorch version).
99+
100+ :param problem: Specification of the tensor product.
101+ :param deterministic: if ``False``, uses atomics for the convolution. If ``True``, uses a deterministic
102+ fixup-based algorithm. `Default`: ``False``.
103+ :param kahan: If ``True``, uses Kahan summation to improve accuracy during aggregation. To use this option,
104+ the input tensors must be in float32 precision AND you must set ``deterministic=True``. *Default*: ``False``.
105+ """
106+
95107 def __init__ (
96108 self , config : TPProblem , deterministic : bool = False , kahan : bool = False
97109 ):
@@ -132,7 +144,27 @@ def forward(
132144 rows : jax .numpy .ndarray ,
133145 cols : jax .numpy .ndarray ,
134146 sender_perm : Optional [jax .numpy .ndarray ] = None ,
135- ) -> jax .numpy .ndarray :
147+ ) -> jax .numpy .ndarray :
148+ r"""
149+ Computes the fused CG tensor product + convolution.
150+
151+ :param X: Tensor of shape ``[|V|, problem.irreps_in1.dim()]``, datatype ``problem.irrep_dtype``.
152+ :param Y: Tensor of shape ``[|E|, problem.irreps_in1.dim()]``, datatype ``problem.irrep_dtype``.
153+ :param W: Tensor of datatype ``problem.weight_dtype`` and shape
154+
155+ * ``[|E|, problem.weight_numel]`` if ``problem.shared_weights=False``
156+ * ``[problem.weight_numel]`` if ``problem.shared_weights=True``
157+
158+ :param rows: Tensor of shape ``[|E|]`` with row indices for each nonzero in the adjacency matrix,
159+ datatype ``np.int32``. Must be row-major sorted along with ``cols`` when ``deterministic=True``.
160+ :param cols: Tensor of shape ``[|E|]`` with column indices for each nonzero in the adjacency matrix,
161+ datatype ``np.int32``.
162+ :param sender_perm: Tensor of shape ``[|E|]`` and ``np.int32`` datatype containing a
163+ permutation that transposes the adjacency matrix nonzeros from row-major to column-major order.
164+ Must be provided when ``deterministic=True``.
165+
166+ :return: Tensor of shape ``[|V|, problem.irreps_out.dim()]``, datatype ``problem.irrep_dtype``.
167+ """
136168 if not self .deterministic :
137169 sender_perm = self .dummy_transpose_perm
138170 else :
0 commit comments