|
2 | 2 | import numpy as np |
3 | 3 | from types import MappingProxyType |
4 | 4 | from openequivariance.core.utils import DTypeEnum |
| 5 | +from openequivariance.core.e3nn_lite import Irreps |
5 | 6 |
|
6 | 7 |
|
7 | 8 | def reorder_helper(schedule, weights_in, direction, has_batch_dim): |
@@ -75,3 +76,72 @@ def string_to_tensor(text: str) -> torch.Tensor: |
75 | 76 | result = torch.tensor(np_bytes, device="cpu") |
76 | 77 | result.requires_grad = False |
77 | 78 | return result |
| 79 | + |
| 80 | + |
def transpose_irreps(
    array: torch.Tensor,
    irreps: Irreps,
    src_layout: str,
    dst_layout: str,
) -> torch.Tensor:
    r"""
    Convert an irrep-packed feature tensor between ``mul_ir`` and ``ir_mul`` layouts.

    Only the trailing feature dimension is rearranged; every leading batch
    dimension passes through untouched. All work is done with differentiable
    PyTorch ops (reshape / transpose / slice assignment), so gradients flow
    through the conversion.

    :param array: Feature tensor of shape ``[..., irreps.dim]``.
    :param irreps: Irreps specification partitioning the trailing dimension
        into per-irrep blocks.
    :param src_layout: Layout of ``array``; one of ``"mul_ir"`` or ``"ir_mul"``.
    :param dst_layout: Desired layout; one of ``"mul_ir"`` or ``"ir_mul"``.

    :returns: A tensor in ``dst_layout`` matching ``array`` in shape, dtype,
        and device. When ``src_layout == dst_layout`` a fresh copy of
        ``array`` is returned.

    :raises TypeError: If ``array`` is not a ``torch.Tensor``.
    :raises ValueError: If either layout string is not ``"mul_ir"`` or
        ``"ir_mul"``.
    """
    # Validate layouts first, then the tensor itself (matches documented raises).
    if src_layout not in ("mul_ir", "ir_mul"):
        raise ValueError(f"Unsupported src_layout: {src_layout}")
    if dst_layout not in ("mul_ir", "ir_mul"):
        raise ValueError(f"Unsupported dst_layout: {dst_layout}")

    if not isinstance(array, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(array)}")

    result = torch.empty_like(array)

    # Same layout on both sides: hand back an independent copy.
    if src_layout == dst_layout:
        result.copy_(array)
        return result

    lead = array.shape[:-1]
    for mul_ir, seg in zip(irreps, irreps.slices()):
        mul = mul_ir.mul
        dim = mul_ir.ir.dim
        # Factor the block according to the *source* layout; transposing the
        # two trailing axes then produces the destination layout.
        inner = (dim, mul) if src_layout == "ir_mul" else (mul, dim)
        chunk = array[..., seg].reshape(*lead, *inner)
        result[..., seg] = chunk.transpose(-1, -2).reshape(*lead, mul * dim)

    return result
0 commit comments