Skip to content

Commit b5866a2

Browse files
committed
Began adding a pair of robust transpose functions.
1 parent a5d0fe5 commit b5866a2

3 files changed

Lines changed: 147 additions & 1 deletion

File tree

openequivariance/openequivariance/_torch/utils.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33
from types import MappingProxyType
44
from openequivariance.core.utils import DTypeEnum
5+
from openequivariance.core.e3nn_lite import Irreps
56

67

78
def reorder_helper(schedule, weights_in, direction, has_batch_dim):
@@ -75,3 +76,72 @@ def string_to_tensor(text: str) -> torch.Tensor:
7576
result = torch.tensor(np_bytes, device="cpu")
7677
result.requires_grad = False
7778
return result
79+
80+
81+
def transpose_irreps(
    array: torch.Tensor,
    irreps: "Irreps",
    src_layout: str,
    dst_layout: str,
) -> torch.Tensor:
    r"""
    Transpose irrep-packed feature tensors between ``mul_ir`` and ``ir_mul`` layouts.

    The function operates on the trailing feature dimension and preserves all leading
    batch dimensions. Only out-of-place, differentiable PyTorch tensor operations are
    used, so gradients propagate through the transpose and the function remains valid
    under ``torch.func`` transforms (which reject in-place mutation).

    :param array: Input feature tensor with shape ``[..., irreps.dim]``.
    :param irreps: Irreps specification describing how the trailing feature dimension
        is partitioned into irrep blocks.
    :param src_layout: Source layout. Must be either ``"mul_ir"`` or ``"ir_mul"``.
    :param dst_layout: Destination layout. Must be either ``"mul_ir"`` or ``"ir_mul"``.

    :returns: Tensor in ``dst_layout`` with the same shape, dtype, and device as ``array``.
        If ``src_layout == dst_layout``, returns a clone of ``array``.

    :raises TypeError: If ``array`` is not a ``torch.Tensor``.
    :raises ValueError: If ``src_layout`` or ``dst_layout`` is not one of
        ``"mul_ir"`` or ``"ir_mul"``.
    """
    # Validate the argument type before interpreting the layout strings.
    if not isinstance(array, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(array)}")
    if src_layout not in ("mul_ir", "ir_mul"):
        raise ValueError(f"Unsupported src_layout: {src_layout}")
    if dst_layout not in ("mul_ir", "ir_mul"):
        raise ValueError(f"Unsupported dst_layout: {dst_layout}")

    if src_layout == dst_layout:
        return array.clone()

    batch_shape = array.shape[:-1]
    blocks = []
    for mul_ir, seg in zip(irreps, irreps.slices()):
        mul = mul_ir.mul
        dim = mul_ir.ir.dim
        # View the packed segment with the source layout's (outer, inner) axes,
        # swap them, then flatten back into a packed segment in the destination
        # layout. With both layouts validated and unequal, the source layout
        # alone determines the axis order.
        inner = (dim, mul) if src_layout == "ir_mul" else (mul, dim)
        block = array[..., seg.start : seg.stop]
        blocks.append(
            block.reshape(*batch_shape, *inner)
            .transpose(-1, -2)
            .reshape(*batch_shape, mul * dim)
        )

    # Degenerate case: empty irreps means a zero-width feature dimension.
    if not blocks:
        return array.clone()

    # A single out-of-place concatenation replaces in-place writes into an
    # uninitialized buffer; autograd tracks it with no special handling.
    return torch.cat(blocks, dim=-1)

openequivariance/openequivariance/core/e3nn_lite.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,7 @@ class TPProblem:
386386
:param internal_weights: Must be False; OpenEquivariance does not support internal weights. *Default*: False.
387387
:param irrep_normalization: One of ``["component", "norm", "none"]``. *Default*: "component".
388388
:param path_normalization: One of ``["element", "path", "none"]``. *Default*: "element".
389+
:param layout: One of ``["mul_ir", "ir_mul"]``, giving the layout of irreps for all inputs and outputs. *Default*: "mul_ir".
389390
"""
390391

391392
instructions: List[Any]
Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,81 @@
1+
import jax
2+
import jax.numpy as jnp
3+
4+
from openequivariance.core.e3nn_lite import Irreps
15
from openequivariance.jax.TensorProduct import TensorProduct as TensorProduct
26
from openequivariance.jax.TensorProductConv import (
37
TensorProductConv as TensorProductConv,
48
)
59

6-
__all__ = ["TensorProduct", "TensorProductConv"]
10+
11+
def transpose_irreps(
    array: jax.Array,
    irreps: "Irreps",
    src_layout: str,
    dst_layout: str,
) -> jax.Array:
    r"""
    Transpose irrep-packed feature arrays between ``mul_ir`` and ``ir_mul`` layouts.

    The function operates on the trailing feature dimension and preserves all leading
    batch dimensions. It uses differentiable JAX operations, so gradients propagate
    through the transpose.

    :param array: Input feature array with shape ``[..., irreps.dim]``.
    :type array: jax.Array
    :param irreps: Irreps specification describing how the trailing feature dimension
        is partitioned into irrep blocks.
    :type irreps: Irreps
    :param src_layout: Source layout. Must be either ``"mul_ir"`` or ``"ir_mul"``.
    :type src_layout: str
    :param dst_layout: Destination layout. Must be either ``"mul_ir"`` or ``"ir_mul"``.
    :type dst_layout: str

    :returns: Array in ``dst_layout`` with the same shape, dtype, and device as ``array``.
        If ``src_layout == dst_layout``, returns a copy of ``array``.
    :rtype: jax.Array

    :raises ValueError: If ``src_layout`` or ``dst_layout`` is not one of
        ``"mul_ir"`` or ``"ir_mul"``.
    """
    if src_layout not in ("mul_ir", "ir_mul"):
        raise ValueError(f"Unsupported src_layout: {src_layout}")
    if dst_layout not in ("mul_ir", "ir_mul"):
        raise ValueError(f"Unsupported dst_layout: {dst_layout}")

    x = jnp.asarray(array)
    if src_layout == dst_layout:
        return jnp.array(x, copy=True)

    batch_shape = x.shape[:-1]
    segments = []
    for mul_ir, seg in zip(irreps, irreps.slices()):
        mul = mul_ir.mul
        dim = mul_ir.ir.dim
        # View the packed segment with the source layout's (outer, inner) axes,
        # swap them, then flatten back into a packed segment in the destination
        # layout. With both layouts validated and unequal, the source layout
        # alone determines the axis order.
        inner = (dim, mul) if src_layout == "ir_mul" else (mul, dim)
        block = x[..., seg.start : seg.stop]
        segments.append(
            block.reshape(*batch_shape, *inner)
            .swapaxes(-1, -2)
            .reshape(*batch_shape, mul * dim)
        )

    # Degenerate case: empty irreps means a zero-width feature dimension.
    if not segments:
        return jnp.array(x, copy=True)

    # A single concatenation is O(N); the previous per-segment ``.at[].set()``
    # updates each copied the whole output buffer when run eagerly.
    return jnp.concatenate(segments, axis=-1)
79+
80+
81+
__all__ = ["TensorProduct", "TensorProductConv", "transpose_irreps"]

0 commit comments

Comments
 (0)