Skip to content

Commit 22c8bb9

Browse files
Exposes / Documents Differentiable Weight Reordering Functions + Removes AOTI Warnings (#145)
* Eliminated AOTI warnings. * Updates to expose torch weight reordering. * Fixed a few more bugs. * Updated docs. * Minor fixes. * Linted.
1 parent 45e2cd8 commit 22c8bb9

9 files changed

Lines changed: 138 additions & 112 deletions

File tree

docs/api.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@ trying our code. OpenEquivariance cannot accelerate all tensor products; see
1818
:doc:`this page </supported_ops>` for a list of supported configurations.
1919

2020
.. autoclass:: openequivariance.TensorProduct
21-
:members:
21+
:members: forward, reorder_weights_from_e3nn, reorder_weights_to_e3nn, to
2222
:undoc-members:
2323
:exclude-members: name
2424

2525
.. autoclass:: openequivariance.TensorProductConv
26-
:members:
26+
:members: forward, reorder_weights_from_e3nn, reorder_weights_to_e3nn, to
2727
:undoc-members:
2828
:exclude-members: name
2929

openequivariance/benchmark/correctness_utils.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -214,13 +214,9 @@ def correctness_double_backward(
214214
if impl == CUETensorProduct and problem.shared_weights:
215215
weights = weights[np.newaxis, :]
216216

217-
weights_reordered = np.zeros_like(weights)
218-
if tp.reorder_weights_e3nn_to_oeq is not None:
219-
tp.reorder_weights_e3nn_to_oeq(
220-
weights, weights_reordered, not tp.config.shared_weights
221-
)
222-
else:
223-
weights_reordered = weights
217+
weights_reordered = tp.reorder_weights_from_e3nn(
218+
weights, not tp.config.shared_weights
219+
)
224220

225221
in1_torch = torch.tensor(in1, device="cuda", requires_grad=True)
226222
in2_torch = torch.tensor(in2, device="cuda", requires_grad=True)
@@ -248,11 +244,9 @@ def correctness_double_backward(
248244
)
249245

250246
weights_grad = weights_torch.grad.detach().cpu().numpy()
251-
if tp.reorder_weights_oeq_to_e3nn is not None:
252-
weights_grad_copy = weights_grad.copy()
253-
tp.reorder_weights_oeq_to_e3nn(
254-
weights_grad_copy, weights_grad, not tp.config.shared_weights
255-
)
247+
weights_grad = tp.reorder_weights_to_e3nn(
248+
weights_grad, not tp.config.shared_weights
249+
)
256250

257251
tensors.append(
258252
(

openequivariance/implementations/ComputationSchedule.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -620,7 +620,7 @@ def calculate_backward_smem(
620620
smem=self.memory_per_warp * warps_per_block,
621621
)
622622

623-
def reorder_weights(self, weights_in, weights_out, direction, has_batch_dim):
623+
def reorder_weights(self, weights_in, direction, has_batch_dim):
624624
"""
625625
Reorders weights from the canonical e3nn form to the
626626
form that LoopUnrollTP can ingest. Can also reorder the parameters
@@ -629,7 +629,9 @@ def reorder_weights(self, weights_in, weights_out, direction, has_batch_dim):
629629
If has_batch_dim is true, the first dimension of the input weight matrix
630630
is treated as the batch dimension.
631631
"""
632-
weights_out *= 0.0
632+
import torch # TODO-someday: no need to specialize this to PyTorch
633+
634+
weights_out = torch.zeros_like(weights_in)
633635
assert direction in ["forward", "backward"]
634636
for i, child_inst in enumerate(self.problem_splitter.new_instructions):
635637
parent_start, parent_end = (
@@ -670,16 +672,41 @@ def reorder_weights(self, weights_in, weights_out, direction, has_batch_dim):
670672
sliced_weights = weights_in[tuple(parent_range)].reshape(parent_shape)[
671673
tuple(weights_subrange)
672674
]
673-
weights_out[tuple(child_range)] = sliced_weights.transpose(
675+
weights_out[tuple(child_range)] = sliced_weights.permute(
674676
transpose_perm
675677
).reshape(reshape_size)
676678
elif direction == "backward":
677679
transpose_child_shape = [child_shape[i] for i in transpose_perm]
678680
sliced_weights = (
679681
weights_in[tuple(child_range)]
680682
.reshape(transpose_child_shape)
681-
.transpose(transpose_perm)
683+
.permute(transpose_perm)
682684
)
683685
weights_out[tuple(parent_range)].reshape(parent_shape)[
684686
tuple(weights_subrange)
685687
] = sliced_weights.flatten().reshape(child_shape)
688+
689+
return weights_out
690+
691+
def reorder_weights_numpy(self, weights_in, direction, has_batch_dim):
692+
import torch
693+
694+
weights_in = torch.from_numpy(weights_in.copy())
695+
result = self.reorder_weights(weights_in, direction, has_batch_dim)
696+
return result.detach().cpu().numpy().copy()
697+
698+
def reorder_weights_from_e3nn(self, weights_in, has_batch_dim):
699+
import torch
700+
701+
if isinstance(weights_in, np.ndarray):
702+
return self.reorder_weights_numpy(weights_in, "forward", has_batch_dim)
703+
elif isinstance(weights_in, torch.Tensor):
704+
return self.reorder_weights(weights_in, "forward", has_batch_dim)
705+
706+
def reorder_weights_to_e3nn(self, weights_in, has_batch_dim):
707+
import torch
708+
709+
if isinstance(weights_in, np.ndarray):
710+
return self.reorder_weights_numpy(weights_in, "backward", has_batch_dim)
711+
elif isinstance(weights_in, torch.Tensor):
712+
return self.reorder_weights(weights_in, "backward", has_batch_dim)

openequivariance/implementations/LoopUnrollTP.py

Lines changed: 10 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -120,19 +120,13 @@ def generate_double_backward_schedule(warps_per_block):
120120
},
121121
)
122122
logger.info("Kernel compiled!")
123-
124123
logger.info(f"Kernel File Size: {len(self.jit_kernel) // 1024} KB")
125124

126-
self.reorder_weights_e3nn_to_oeq = (
127-
lambda input, output, has_batch_dim: self.forward_schedule.reorder_weights(
128-
input, output, "forward", has_batch_dim
129-
)
130-
)
131-
self.reorder_weights_oeq_to_e3nn = (
132-
lambda input, output, has_batch_dim: self.forward_schedule.reorder_weights(
133-
input, output, "backward", has_batch_dim
134-
)
135-
)
125+
def reorder_weights_from_e3nn(self, weights, has_batch_dim=True):
126+
return self.forward_schedule.reorder_weights_from_e3nn(weights, has_batch_dim)
127+
128+
def reorder_weights_to_e3nn(self, weights, has_batch_dim=True):
129+
return self.forward_schedule.reorder_weights_to_e3nn(weights, has_batch_dim)
136130

137131
@classmethod
138132
def register_torch_fakes(cls):
@@ -177,24 +171,15 @@ def __setstate__(self, state):
177171
self.dbl_bwd_config = state["dbl_bwd_config"]
178172
self.kernel_dims = state["kernel_dims"]
179173

180-
def exec_tensor_product_rawptr(
181-
self, batch: int, L1_in: int, L2_in: int, L3_out: int, weights: int
182-
) -> None:
174+
def exec_tensor_product_rawptr(*args, **kwargs):
183175
pass
184176

185-
def backward_rawptr(
186-
self,
187-
batch_size: int,
188-
L1_in: int,
189-
L1_grad: int,
190-
L2_in: int,
191-
L2_grad: int,
192-
weights: int,
193-
weights_grad: int,
194-
L3_grad: int,
195-
):
177+
def backward_rawptr(*args, **kwargs):
196178
pass
197179

180+
def get_L3_dim(self):
181+
return self.kernel_dims["L3_dim"]
182+
198183
@torch.library.register_fake("libtorch_tp_jit::jit_tp_forward")
199184
def fake_forward(jit, L1_in, L2_in, W):
200185
L3_dim = None

openequivariance/implementations/TensorProduct.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ def _init_class(self):
3838
self.forward = self.forward_opaque
3939

4040
def to(self, *args, **kwargs):
41+
r"""
42+
See `torch.nn.Module.to() <https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.to>`_.
43+
"""
4144
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
4245
*args, **kwargs
4346
)

openequivariance/implementations/TensorProductBase.py

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ def __init__(self, config: TPProblem, torch_op: bool = False):
3232
config.irreps_out,
3333
)
3434
self.irrep_dtype, self.weight_dtype = config.irrep_dtype, config.weight_dtype
35-
self.reorder_weights_e3nn_to_oeq, self.reorder_weights_oeq_to_e3nn = None, None
3635

3736
self.tp_id = TensorProductBase.next_tp_id
3837
TensorProductBase.next_tp_id += 1
@@ -44,6 +43,34 @@ def __init__(self, config: TPProblem, torch_op: bool = False):
4443
def __call__(self, L1_in, L2_in, weights):
4544
return self.forward(L1_in, L2_in, weights)
4645

46+
def reorder_weights_from_e3nn(self, weights, has_batch_dim: bool = True):
47+
r"""
48+
Reorders weights from ``e3nn`` canonical order to the order used by ``oeq``.
49+
50+
:param weights: Weights in ``e3nn`` canonical order, either an
51+
np.ndarray or a torch.Tensor. Tensor of dimensions ``[B, problem.weight_numel]``
52+
when ``has_batch_dim=True``, otherwise of dimensions ``[problem.weight_numel]``.
53+
54+
:param has_batch_dim: If ``True``, treats the first dimension of weights as a batch dimension. Default: ``True``.
55+
56+
:return: Weights in ``oeq`` order. Output type is identical to input.
57+
"""
58+
return weights
59+
60+
def reorder_weights_to_e3nn(self, weights, has_batch_dim: bool = True):
61+
r"""
62+
Reorders weights from ``oeq`` canonical order to the order used by ``e3nn``.
63+
64+
:param weights: Weights in ``oeq`` canonical order, either an
65+
np.ndarray or a torch.Tensor. Tensor of dimensions ``[B, problem.weight_numel]``
66+
when ``has_batch_dim=True``, otherwise of dimensions ``[problem.weight_numel]``.
67+
68+
:param has_batch_dim: If ``True``, treats the first dimension of weights as a batch dimension. Default: ``True``.
69+
70+
:return: Weights in ``e3nn`` order. Output type is identical to input.
71+
"""
72+
return weights
73+
4774
def forward_raw(
4875
self,
4976
batch: np.uint64,
@@ -76,13 +103,9 @@ def forward_cpu(
76103
L3_out: np.ndarray,
77104
weights: np.ndarray,
78105
) -> None:
79-
weights_chunked = np.zeros_like(weights)
80-
if self.reorder_weights_e3nn_to_oeq is not None:
81-
self.reorder_weights_e3nn_to_oeq(
82-
weights, weights_chunked, not self.config.shared_weights
83-
)
84-
else:
85-
weights_chunked = weights
106+
weights_chunked = self.reorder_weights_from_e3nn(
107+
weights, not self.config.shared_weights
108+
)
86109

87110
batch = L1_in.shape[0]
88111
L1_d = DeviceBuffer(L1_in)
@@ -101,13 +124,9 @@ def forward_cpu(
101124
def backward_cpu(
102125
self, L1_in, L1_grad, L2_in, L2_grad, L3_grad, weights, weights_grad
103126
) -> None:
104-
weights_chunked = np.zeros_like(weights)
105-
if self.reorder_weights_e3nn_to_oeq is not None:
106-
self.reorder_weights_e3nn_to_oeq(
107-
weights, weights_chunked, not self.config.shared_weights
108-
)
109-
else:
110-
weights_chunked = weights
127+
weights_chunked = self.reorder_weights_from_e3nn(
128+
weights, not self.config.shared_weights
129+
)
111130

112131
batch = L1_in.shape[0]
113132
L1_d, L2_d, L3_d = (
@@ -136,11 +155,9 @@ def backward_cpu(
136155
L2_grad_d.copy_to_host()
137156
weights_grad_d.copy_to_host()
138157

139-
if self.reorder_weights_oeq_to_e3nn is not None:
140-
weights_grad_copy = weights_grad.copy()
141-
self.reorder_weights_oeq_to_e3nn(
142-
weights_grad_copy, weights_grad, not self.config.shared_weights
143-
)
158+
weights_grad[:] = self.reorder_weights_to_e3nn(
159+
weights_grad, not self.config.shared_weights
160+
)
144161

145162
def benchmark_forward(
146163
self,

openequivariance/implementations/convolution/ConvolutionBase.py

Lines changed: 28 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,18 @@ def __init__(
117117
self.workspace_ptr = 0
118118
self.workspace_size = 0
119119

120+
def reorder_weights_from_e3nn(self, weights, has_batch_dim=True):
121+
r"""
122+
See :py:func:`oeq.TensorProduct.reorder_weights_from_e3nn`.
123+
"""
124+
return weights
125+
126+
def reorder_weights_to_e3nn(self, weights, has_batch_dim=True):
127+
r"""
128+
See :py:func:`oeq.TensorProduct.reorder_weights_to_e3nn`.
129+
"""
130+
return weights
131+
120132
def allocate_workspace(self, size_bytes):
121133
self.workspace_size = size_bytes
122134
if self.torch_op:
@@ -136,13 +148,9 @@ def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph):
136148
assert graph.rows.dtype == self.idx_dtype
137149
assert graph.cols.dtype == self.idx_dtype
138150

139-
weights_chunked = np.zeros_like(weights)
140-
if self.reorder_weights_e3nn_to_oeq is not None:
141-
self.reorder_weights_e3nn_to_oeq(
142-
weights, weights_chunked, not self.config.shared_weights
143-
)
144-
else:
145-
weights_chunked = weights
151+
weights_chunked = self.reorder_weights_from_e3nn(
152+
weights, not self.config.shared_weights
153+
)
146154

147155
L1_d, L2_d, weights_d = (
148156
DeviceBuffer(L1_in),
@@ -174,13 +182,9 @@ def backward_cpu(
174182
assert graph.rows.dtype == self.idx_dtype
175183
assert graph.cols.dtype == self.idx_dtype
176184

177-
weights_chunked = np.zeros_like(weights)
178-
if self.reorder_weights_e3nn_to_oeq is not None:
179-
self.reorder_weights_e3nn_to_oeq(
180-
weights, weights_chunked, not self.config.shared_weights
181-
)
182-
else:
183-
weights_chunked = weights
185+
weights_chunked = self.reorder_weights_from_e3nn(
186+
weights, not self.config.shared_weights
187+
)
184188

185189
L1_d = DeviceBuffer(L1_in)
186190
L2_d = DeviceBuffer(L2_in)
@@ -219,11 +223,9 @@ def backward_cpu(
219223
L2_grad_d.copy_to_host()
220224
weights_grad_d.copy_to_host()
221225

222-
if self.reorder_weights_oeq_to_e3nn is not None:
223-
weights_grad_copy = weights_grad.copy()
224-
self.reorder_weights_oeq_to_e3nn(
225-
weights_grad_copy, weights_grad, not self.config.shared_weights
226-
)
226+
weights_grad[:] = self.reorder_weights_to_e3nn(
227+
weights_grad, not self.config.shared_weights
228+
)
227229

228230
return L1_grad, L2_grad, weights_grad
229231

@@ -712,17 +714,10 @@ def test_correctness_double_backward(
712714
in1_torch = torch.tensor(in1, device="cuda", requires_grad=True)
713715
in2_torch = torch.tensor(in2, device="cuda", requires_grad=True)
714716

715-
weights_reordered = np.zeros_like(weights)
716-
if (
717-
i == 0
718-
and hasattr(self, "reorder_weights_e3nn_to_oeq")
719-
and self.reorder_weights_e3nn_to_oeq is not None
720-
):
721-
self.reorder_weights_e3nn_to_oeq(
722-
weights, weights_reordered, not self.config.shared_weights
723-
)
724-
else:
725-
weights_reordered[:] = weights
717+
weights_reordered = tp.reorder_weights_from_e3nn(
718+
weights, not self.config.shared_weights
719+
)
720+
726721
weights_torch = torch.tensor(
727722
weights_reordered, device="cuda", requires_grad=True
728723
)
@@ -754,15 +749,9 @@ def test_correctness_double_backward(
754749
)
755750

756751
weights_grad = weights_torch.grad.detach().cpu().numpy()
757-
if (
758-
i == 0
759-
and hasattr(self, "reorder_weights_e3nn_to_oeq")
760-
and self.reorder_weights_oeq_to_e3nn is not None
761-
):
762-
weights_grad_copy = weights_grad.copy()
763-
self.reorder_weights_oeq_to_e3nn(
764-
weights_grad_copy, weights_grad, not self.config.shared_weights
765-
)
752+
weights_grad = tp.reorder_weights_to_e3nn(
753+
weights_grad, not self.config.shared_weights
754+
)
766755

767756
tensors.append(
768757
(

0 commit comments

Comments
 (0)