Commit d924503

Abstracted away reordering.
1 parent 1b0deb0 commit d924503

4 files changed

Lines changed: 90 additions & 65 deletions

File tree

openequivariance/openequivariance/core/ComputationSchedule.py
openequivariance/openequivariance/impl_torch/TensorProduct.py
openequivariance/openequivariance/impl_torch/TensorProductConv.py
openequivariance/openequivariance/impl_torch/utils.py

openequivariance/openequivariance/core/ComputationSchedule.py

Lines changed: 31 additions & 61 deletions
@@ -619,40 +619,33 @@ def calculate_backward_smem(
             smem=self.memory_per_warp * warps_per_block,
         )
 
-    def reorder_weights(self, weights_in, direction, has_batch_dim):
+    def weight_reordering_info(self, weights_in, has_batch_dim):
         """
-        Reorders weights from the canonical e3nn form to the
-        form that LoopUnrollTP can ingest. Can also reorder the parameters
-        of a dense neural network layer that produces the weight matrix.
-
-        If has_batch_dim is true, the first dimension of the input weight matrix
-        is treated as the batch dimension.
+        Calculates all shapes, slices, and permutation info to reorder
+        weights.
         """
-        import torch  # TODO-someday: no need to specialize this to PyTorch
+        batch_dim = weights_in.shape[0]
+        reorder_specs = []
 
-        weights_out = torch.zeros_like(weights_in)
-        assert direction in ["forward", "backward"]
         for i, child_inst in enumerate(self.problem_splitter.new_instructions):
             parent_start, parent_end = (
                 child_inst.parent_weights_start,
                 child_inst.parent_weights_end,
             )
             parent_shape = list(child_inst.parent_weights_shape)
+            parent_range = [slice(parent_start, parent_end)]
 
             child_start, child_end, child_shape = (
                 self.updated_config.weight_range_and_shape_for_instruction(i)
             )
-
-            parent_range, child_range = (
-                [slice(parent_start, parent_end)],
-                [slice(child_start, child_end)],
-            )
+            child_range = [slice(child_start, child_end)]
+
             weights_subrange = child_inst.weights_subrange
-            batch_dim = weights_in.shape[0]
+
             reshape_size = [-1]
             transpose_perm = None
-
             connection_mode = self.updated_config.instructions[i].connection_mode
+
             if connection_mode == "uvu":
                 transpose_perm = [1, 0]
             elif connection_mode == "uvw":
@@ -662,50 +655,27 @@ def reorder_weights(self, weights_in, direction, has_batch_dim):
                 child_range = [slice(0, batch_dim)] + child_range
                 parent_range = [slice(0, batch_dim)] + parent_range
                 parent_shape = [batch_dim] + parent_shape
+
                 child_shape = [batch_dim] + list(child_shape)
                 weights_subrange = [slice(0, batch_dim)] + child_inst.weights_subrange
                 reshape_size = [batch_dim] + reshape_size
-                transpose_perm = [0] + [i + 1 for i in transpose_perm]
-
-            if direction == "forward":
-                sliced_weights = weights_in[tuple(parent_range)].reshape(parent_shape)[
-                    tuple(weights_subrange)
-                ]
-                weights_out[tuple(child_range)] = sliced_weights.permute(
-                    transpose_perm
-                ).reshape(reshape_size)
-            elif direction == "backward":
-                transpose_child_shape = [child_shape[i] for i in transpose_perm]
-                sliced_weights = (
-                    weights_in[tuple(child_range)]
-                    .reshape(transpose_child_shape)
-                    .permute(transpose_perm)
-                )
-                weights_out[tuple(parent_range)].reshape(parent_shape)[
-                    tuple(weights_subrange)
-                ] = sliced_weights.flatten().reshape(child_shape)
-
-        return weights_out
-
-    def reorder_weights_numpy(self, weights_in, direction, has_batch_dim):
-        import torch
-
-        weights_in = torch.from_numpy(weights_in.copy())
-        result = self.reorder_weights(weights_in, direction, has_batch_dim)
-        return result.detach().cpu().numpy().copy()
-
-    def reorder_weights_from_e3nn(self, weights_in, has_batch_dim):
-        import torch
-
-        if isinstance(weights_in, np.ndarray):
-            return self.reorder_weights_numpy(weights_in, "forward", has_batch_dim)
-        elif isinstance(weights_in, torch.Tensor):
-            return self.reorder_weights(weights_in, "forward", has_batch_dim)
-
-    def reorder_weights_to_e3nn(self, weights_in, has_batch_dim):
-        import torch
-
-        if isinstance(weights_in, np.ndarray):
-            return self.reorder_weights_numpy(weights_in, "backward", has_batch_dim)
-        elif isinstance(weights_in, torch.Tensor):
-            return self.reorder_weights(weights_in, "backward", has_batch_dim)
+
+                if transpose_perm is not None:
+                    transpose_perm = [0] + [k + 1 for k in transpose_perm]
+
+            transpose_child_shape = None
+            if transpose_perm is not None:
+                transpose_child_shape = [child_shape[k] for k in transpose_perm]
+
+            reorder_specs.append({
+                "parent_range": tuple(parent_range),
+                "parent_shape": parent_shape,
+                "weights_subrange": tuple(weights_subrange),
+                "child_range": tuple(child_range),
+                "child_shape": child_shape,
+                "transpose_perm": transpose_perm,
+                "reshape_size": reshape_size,
+                "transpose_child_shape": transpose_child_shape,
+            })
+
+        return reorder_specs
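
The net effect: weight_reordering_info is now a pure metadata computation that returns per-instruction slices, shapes, and permutations instead of touching tensors, which addresses the removed TODO-someday about not specializing the schedule to PyTorch. As a hedged illustration (not part of this commit; the helper name and the example spec below are hypothetical), the same specs could drive a backend-agnostic reorder, e.g. in plain NumPy:

import numpy as np

def apply_forward_spec_numpy(weights_in, weights_out, spec):
    # Slice the parent (e3nn-layout) region, view it in its logical shape,
    # and restrict to the sub-block owned by this child instruction.
    sliced = weights_in[spec["parent_range"]].reshape(spec["parent_shape"])[
        spec["weights_subrange"]
    ]
    # Transpose into the kernel's layout and flatten into the child's slot.
    weights_out[spec["child_range"]] = sliced.transpose(
        spec["transpose_perm"]
    ).reshape(spec["reshape_size"])

# Hypothetical spec for a single "uvu" instruction with no batch dimension:
# an 8x4 block stored flat in e3nn order, transposed to 4x8 in the child layout.
spec = {
    "parent_range": (slice(0, 32),),
    "parent_shape": [8, 4],
    "weights_subrange": (slice(0, 8), slice(0, 4)),
    "child_range": (slice(0, 32),),
    "transpose_perm": [1, 0],
    "reshape_size": [-1],
}
w_in = np.arange(32, dtype=np.float64)
w_out = np.zeros_like(w_in)
apply_forward_spec_numpy(w_in, w_out, spec)
assert np.array_equal(w_out, w_in.reshape(8, 4).T.reshape(-1))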

openequivariance/openequivariance/impl_torch/TensorProduct.py

Lines changed: 4 additions & 2 deletions
@@ -5,6 +5,8 @@
 import typing
 from openequivariance.core.utils import torch_to_oeq_dtype
 from openequivariance.benchmark.logging_utils import getLogger
+from openequivariance.impl_torch.utils import reorder_torch
+
 
 logger = getLogger()
 
@@ -90,10 +92,10 @@ def __setstate__(self, state):
         self._init_class()
 
     def reorder_weights_from_e3nn(self, weights, has_batch_dim=True):
-        return self.forward_schedule.reorder_weights_from_e3nn(weights, has_batch_dim)
+        return reorder_torch(self.forward_schedule, weights, "forward", not self.config.shared_weights)
 
     def reorder_weights_to_e3nn(self, weights, has_batch_dim=True):
-        return self.forward_schedule.reorder_weights_to_e3nn(weights, has_batch_dim)
+        return reorder_torch(self.forward_schedule, weights, "backward", not self.config.shared_weights)
 
     def forward(
         self, x: torch.Tensor, y: torch.Tensor, W: torch.Tensor
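
Both call sites now derive the batch-dimension flag from not self.config.shared_weights; the has_batch_dim=True parameter is kept for API compatibility but is no longer consulted. A hedged usage sketch of the round trip through these two methods follows; the TPProblem/TensorProduct construction and the weight_numel attribute are assumptions about the library's public API, not something this diff shows:

import torch
import openequivariance as oeq

# Assumed construction: one trainable "uvu" instruction, unshared weights.
problem = oeq.TPProblem(
    oeq.Irreps("32x1e"), oeq.Irreps("1x1e"), oeq.Irreps("32x1e"),
    [(0, 0, 0, "uvu", True)],
    shared_weights=False, internal_weights=False,
)
tp = oeq.TensorProduct(problem)

w_e3nn = torch.randn(128, problem.weight_numel)  # batch of e3nn-layout weights
w_oeq = tp.reorder_weights_from_e3nn(w_e3nn)     # "forward": e3nn -> kernel layout
w_back = tp.reorder_weights_to_e3nn(w_oeq)       # "backward": kernel -> e3nn layout
assert torch.allclose(w_e3nn, w_back)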

openequivariance/openequivariance/impl_torch/TensorProductConv.py

Lines changed: 3 additions & 2 deletions
@@ -19,6 +19,7 @@
 from openequivariance import TPProblem
 from openequivariance.core.utils import torch_to_oeq_dtype
 from openequivariance.core.dtype_enum import enum_to_torch_dtype
+from openequivariance.impl_torch.utils import reorder_torch
 
 from openequivariance.benchmark.logging_utils import getLogger
 
@@ -418,10 +419,10 @@ def double_backward(ctx, grad_output):
         )
 
     def reorder_weights_from_e3nn(self, weights, has_batch_dim=True):
-        return self.forward_schedule.reorder_weights_from_e3nn(weights, has_batch_dim)
+        return reorder_torch(self.forward_schedule, weights, "forward", not self.config.shared_weights)
 
     def reorder_weights_to_e3nn(self, weights, has_batch_dim=True):
-        return self.forward_schedule.reorder_weights_to_e3nn(weights, has_batch_dim)
+        return reorder_torch(self.forward_schedule, weights, "backward", not self.config.shared_weights)
 
     @staticmethod
     def name():
openequivariance/openequivariance/impl_torch/utils.py (new file)

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+import torch
+
+def reorder_helper(schedule, weights_in, direction, has_batch_dim):
+    assert direction in ["forward", "backward"]
+
+    specs = schedule.weight_reordering_info(weights_in, has_batch_dim)
+    weights_out = torch.zeros_like(weights_in)
+
+    for spec in specs:
+        parent_range = spec["parent_range"]
+        parent_shape = spec["parent_shape"]
+        weights_subrange = spec["weights_subrange"]
+        child_range = spec["child_range"]
+        transpose_perm = spec["transpose_perm"]
+
+        if direction == "forward":
+            reshape_size = spec["reshape_size"]
+
+            sliced_weights = weights_in[parent_range].reshape(parent_shape)[
+                weights_subrange
+            ]
+
+            weights_out[child_range] = sliced_weights.permute(
+                transpose_perm
+            ).reshape(reshape_size)
+
+        elif direction == "backward":
+            transpose_child_shape = spec["transpose_child_shape"]
+            child_shape = spec["child_shape"]
+
+            sliced_weights = (
+                weights_in[child_range]
+                .reshape(transpose_child_shape)
+                .permute(transpose_perm)
+            )
+
+            weights_out[parent_range].reshape(parent_shape)[
+                weights_subrange
+            ] = sliced_weights.flatten().reshape(child_shape)
+
+    return weights_out
+
+def reorder_numpy_helper(schedule, weights_in, direction, has_batch_dim):
+    weights_in = torch.from_numpy(weights_in.copy())
+    result = reorder_helper(schedule, weights_in, direction, has_batch_dim)
+    return result.detach().cpu().numpy().copy()
+
+def reorder_torch(schedule, weights_in, direction, has_batch_dim):
+    if isinstance(weights_in, torch.Tensor):
+        return reorder_helper(schedule, weights_in, direction, has_batch_dim)
+    else:
+        return reorder_numpy_helper(schedule, weights_in, direction, has_batch_dim)
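
With this split, the torch dependency for reordering lives entirely in impl_torch: the schedule emits slicing metadata and these helpers apply it, with NumPy inputs routed through torch and back. One property worth checking (a hedged sketch, not part of the commit; it assumes the per-instruction subranges tile the weight vector, so the two directions invert each other, and that weight_numel matches the schedule) is the forward/backward round trip:

import numpy as np
from openequivariance.impl_torch.utils import reorder_torch

def check_round_trip(schedule, batch=16, weight_numel=64, seed=0):
    # Reorder e3nn-layout weights into the kernel layout and back; if every
    # weight is covered exactly once, the composition is the identity.
    rng = np.random.default_rng(seed)
    w = rng.standard_normal((batch, weight_numel)).astype(np.float32)
    w_oeq = reorder_torch(schedule, w, "forward", has_batch_dim=True)
    w_back = reorder_torch(schedule, w_oeq, "backward", has_batch_dim=True)
    np.testing.assert_allclose(w, w_back)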
