Skip to content

Commit df24066

Browse files
ir_mul layout support (#192)
1 parent 519d003 commit df24066

41 files changed

Lines changed: 1663 additions & 1126 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
## Latest Changes
22

3+
### v0.6.5 (2026-03-22)
4+
This release brings `ir_mul` layout support for
5+
OpenEquivariance. Pass the parameter
6+
`layout='ir_mul'` to any `TPProblem` instance to use
7+
a transposed layout for the input and output
8+
irreps. To transpose input and output irreps use
9+
`oeq.transpose_irreps` or `oeq.jax.transpose_irreps`;
10+
see our API page for usage details.
11+
312
### v0.6.4 (2026-03-05)
413
Bugfix: added missing MLIR lowerings for
514
a pair of JAX primitives (thanks @teddykoker!)

docs/api.rst

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ PyTorch API
3030
:undoc-members:
3131
:exclude-members: name
3232

33+
.. autofunction:: openequivariance.transpose_irreps
34+
3335
.. autofunction:: openequivariance.torch_to_oeq_dtype
3436

3537
.. autofunction:: openequivariance.torch_ext_so_path
@@ -54,7 +56,9 @@ breaking the PyTorch version of OpenEquivariance.
5456
.. autoclass:: openequivariance.jax.TensorProductConv
5557
:members: forward, reorder_weights_from_e3nn, reorder_weights_to_e3nn
5658
:undoc-members:
57-
:exclude-members:
59+
:exclude-members:
60+
61+
.. autofunction:: openequivariance.jax.transpose_irreps
5862

5963
Common API
6064
---------------------

docs/conf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
"openequivariance._torch.extlib",
3939
"openequivariance.jax.extlib",
4040
"openequivariance_extjax",
41+
"openequivariance.jax.jvp.tp_prim",
42+
"openequivariance.jax.jvp.conv_prim",
4143
"jinja2",
4244
"numpy",
4345
]

openequivariance/openequivariance/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def _check_package_editable():
3737

3838
from openequivariance._torch.TensorProduct import TensorProduct
3939
from openequivariance._torch.TensorProductConv import TensorProductConv
40+
from openequivariance._torch.utils import transpose_irreps
4041

4142
from openequivariance._torch.extlib import (
4243
torch_ext_so_path as torch_ext_so_path_internal,
@@ -111,4 +112,5 @@ def TensorProductConv(*args, **kwargs):
111112
"_check_package_editable",
112113
"torch_ext_so_path",
113114
"jax",
115+
"transpose_irreps",
114116
]

openequivariance/openequivariance/_torch/CUETensorProduct.py

Lines changed: 2 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,12 @@
66

77
from openequivariance.core.TensorProductBase import TensorProductBase
88
from openequivariance.core.e3nn_lite import TPProblem
9-
from openequivariance.benchmark.logging_utils import getLogger
10-
from openequivariance.benchmark.tpp_creation_utils import (
9+
from openequivariance.core.logging import getLogger
10+
from openequivariance.benchmark.problems import (
1111
ChannelwiseTPP,
1212
FullyConnectedTPProblem,
1313
SingleInstruction,
1414
)
15-
from openequivariance.core.utils import count_cg_non_zero
1615

1716
os.environ["CUEQUIVARIANCE_OPS_USE_JIT"] = "1"
1817

@@ -235,57 +234,6 @@ def benchmark_backward(
235234
kernel_names=self.kernel_names,
236235
)
237236

238-
# Copied over from loop unroller to match arithmetic intensity on roofline plots
239-
def calculate_flops_forward(self, batch_size: int) -> dict:
240-
if self.is_uvw:
241-
return super().calculate_flops_forward(batch_size)
242-
else:
243-
tpp = self.config
244-
flop_count = {
245-
"CG_decomposition": 0,
246-
"linear_combination": 0,
247-
"outer_products": 0,
248-
}
249-
for ins in tpp.instructions:
250-
l1, l2, l3 = (
251-
tpp.irreps_in1[ins.i_in1].ir.l,
252-
tpp.irreps_in2[ins.i_in2].ir.l,
253-
tpp.irreps_out[ins.i_out].ir.l,
254-
)
255-
flop_count["CG_decomposition"] += count_cg_non_zero(l1, l2, l3) * (
256-
ins.path_shape[0] * ins.path_shape[1]
257-
)
258-
flop_count["linear_combination"] += (
259-
(2 * l3 + 1) * np.prod(ins.path_shape) if ins.has_weight else 0
260-
)
261-
262-
flop_count["CG_decomposition"] *= 3 * batch_size
263-
flop_count["linear_combination"] *= (
264-
batch_size # Weights do not require FMA here
265-
)
266-
flop_count["total"] = sum(flop_count.values())
267-
return flop_count
268-
269-
def calculate_flops_backward(self, batch_size: int) -> dict:
270-
if self.is_uvw:
271-
return super().calculate_flops_backward(batch_size)
272-
else:
273-
tpp = self.config
274-
flop_count = {"backward": 0}
275-
for ins in tpp.instructions:
276-
l1, l2, l3 = (
277-
tpp.irreps_in1[ins.i_in1].ir.l,
278-
tpp.irreps_in2[ins.i_in2].ir.l,
279-
tpp.irreps_out[ins.i_out].ir.l,
280-
)
281-
flop_count["backward"] += count_cg_non_zero(l1, l2, l3) * (
282-
ins.path_shape[0] * ins.path_shape[1]
283-
)
284-
285-
flop_count["backward"] *= 9 * batch_size
286-
flop_count["total"] = sum(flop_count.values())
287-
return flop_count
288-
289237
@staticmethod
290238
def name():
291239
return "CUETensorProduct"

openequivariance/openequivariance/_torch/E3NNTensorProduct.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from openequivariance.core.TensorProductBase import TensorProductBase
1313
from openequivariance.core.e3nn_lite import TPProblem
14-
from openequivariance.benchmark.logging_utils import getLogger
14+
from openequivariance.core.logging import getLogger
1515
from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin
1616

1717
TORCH_COMPILE_AUTOTUNING_DIR = pathlib.Path("triton_autotuning")

openequivariance/openequivariance/_torch/TensorProduct.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from openequivariance._torch import extlib
44
import torch
55
from openequivariance.core.utils import torch_to_oeq_dtype, dtype_to_enum
6-
from openequivariance.benchmark.logging_utils import getLogger
6+
from openequivariance.core.logging import getLogger
77
from openequivariance._torch.utils import (
88
reorder_torch,
99
string_to_tensor,

openequivariance/openequivariance/_torch/TensorProductConv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
enum_to_torch_dtype,
2424
)
2525

26-
from openequivariance.benchmark.logging_utils import getLogger
26+
from openequivariance.core.logging import getLogger
2727
from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixinConv
2828

2929
logger = getLogger()

openequivariance/openequivariance/_torch/extlib/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import torch
1010

11-
from openequivariance.benchmark.logging_utils import getLogger
11+
from openequivariance.core.logging import getLogger
1212

1313
oeq_root = str(Path(__file__).parent.parent.parent)
1414

openequivariance/openequivariance/_torch/utils.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33
from types import MappingProxyType
44
from openequivariance.core.utils import DTypeEnum
5+
from openequivariance.core.e3nn_lite import Irreps
56

67

78
def reorder_helper(schedule, weights_in, direction, has_batch_dim):
@@ -75,3 +76,72 @@ def string_to_tensor(text: str) -> torch.Tensor:
7576
result = torch.tensor(np_bytes, device="cpu")
7677
result.requires_grad = False
7778
return result
79+
80+
81+
def transpose_irreps(
    array: torch.Tensor,
    irreps: "Irreps",
    src_layout: str,
    dst_layout: str,
) -> torch.Tensor:
    r"""
    Transpose irrep-packed feature tensors between ``mul_ir`` and ``ir_mul`` layouts.

    The function operates on the trailing feature dimension and preserves all leading
    batch dimensions. It uses only differentiable, out-of-place PyTorch tensor
    operations (``reshape``, ``transpose``, ``cat``), so gradients propagate through
    the transpose.

    :param array: Input feature tensor with shape ``[..., irreps.dim]``.
    :param irreps: Irreps specification describing how the trailing feature dimension
        is partitioned into irrep blocks.
    :param src_layout: Source layout. Must be either ``"mul_ir"`` or ``"ir_mul"``.
    :param dst_layout: Destination layout. Must be either ``"mul_ir"`` or ``"ir_mul"``.

    :returns: Tensor in ``dst_layout`` with the same shape, dtype, and device as ``array``.
        If ``src_layout == dst_layout``, returns a clone of ``array``.

    :raises TypeError: If ``array`` is not a ``torch.Tensor``.
    :raises ValueError: If ``src_layout`` or ``dst_layout`` is not one of
        ``"mul_ir"`` or ``"ir_mul"``, or if the trailing dimension of ``array``
        does not equal ``irreps.dim``.
    """
    valid_layouts = ("mul_ir", "ir_mul")
    if src_layout not in valid_layouts:
        raise ValueError(f"Unsupported src_layout: {src_layout}")
    if dst_layout not in valid_layouts:
        raise ValueError(f"Unsupported dst_layout: {dst_layout}")

    if not isinstance(array, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(array)}")

    # Fail early with a clear message instead of a confusing reshape error
    # inside the per-irrep loop below.
    if array.shape[-1] != irreps.dim:
        raise ValueError(
            f"Trailing dimension of array ({array.shape[-1]}) does not match "
            f"irreps.dim ({irreps.dim})"
        )

    # Same layout: nothing to transpose, but return a fresh tensor so callers
    # always own the result (matches the documented clone semantics).
    if src_layout == dst_layout:
        return array.clone()

    batch_shape = array.shape[:-1]
    segments = []
    # irreps.slices() contiguously partitions [0, irreps.dim), one slice per
    # (multiplicity, irrep) block of the trailing dimension.
    for mul_ir, seg in zip(irreps, irreps.slices()):
        mul, ir_dim = mul_ir.mul, mul_ir.ir.dim
        # The two layouts differ only in which axis comes first before the
        # transpose: mul_ir stores [mul, ir_dim]; ir_mul stores [ir_dim, mul].
        rows, cols = (mul, ir_dim) if src_layout == "mul_ir" else (ir_dim, mul)
        block = array[..., seg.start : seg.stop]
        segments.append(
            block.reshape(*batch_shape, rows, cols)
            .transpose(-1, -2)
            .reshape(*batch_shape, rows * cols)
        )

    if not segments:
        # Empty irreps: trailing dimension is 0; nothing to concatenate.
        return array.clone()

    # torch.cat keeps the whole operation out-of-place and fully differentiable.
    return torch.cat(segments, dim=-1)

0 commit comments

Comments
 (0)