Commit 33ed045

More diffs.

1 parent 54f2ee8 commit 33ed045

6 files changed: 305 additions & 95 deletions

openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py

Lines changed: 78 additions & 20 deletions
@@ -1,5 +1,7 @@
 import torch
 
+from openequivariance.core.utils import IrrepLayoutUtils
+
 
 class NumpyDoubleBackwardMixin:
     """
@@ -13,12 +15,30 @@ def double_backward_cpu(
     ):
         assert self.torch_op
 
-        in1_torch = torch.tensor(in1).to("cuda").requires_grad_(True)
-        in2_torch = torch.tensor(in2).to("cuda").requires_grad_(True)
+        layout = self.config.layout
+
+        in1_kernel = IrrepLayoutUtils.transpose_irrep_layout(
+            in1, self.config.irreps_in1, layout, "mul_ir"
+        )
+        in2_kernel = IrrepLayoutUtils.transpose_irrep_layout(
+            in2, self.config.irreps_in2, layout, "mul_ir"
+        )
+        out_grad_kernel = IrrepLayoutUtils.transpose_irrep_layout(
+            out_grad, self.config.irreps_out, layout, "mul_ir"
+        )
+        in1_dgrad_kernel = IrrepLayoutUtils.transpose_irrep_layout(
+            in1_dgrad, self.config.irreps_in1, layout, "mul_ir"
+        )
+        in2_dgrad_kernel = IrrepLayoutUtils.transpose_irrep_layout(
+            in2_dgrad, self.config.irreps_in2, layout, "mul_ir"
+        )
+
+        in1_torch = torch.tensor(in1_kernel).to("cuda").requires_grad_(True)
+        in2_torch = torch.tensor(in2_kernel).to("cuda").requires_grad_(True)
         weights_torch = torch.tensor(weights).to("cuda").requires_grad_(True)
-        out_grad_torch = torch.tensor(out_grad).to("cuda").requires_grad_(True)
-        in1_dgrad_torch = torch.tensor(in1_dgrad).to("cuda")
-        in2_dgrad_torch = torch.tensor(in2_dgrad).to("cuda")
+        out_grad_torch = torch.tensor(out_grad_kernel).to("cuda").requires_grad_(True)
+        in1_dgrad_torch = torch.tensor(in1_dgrad_kernel).to("cuda")
+        in2_dgrad_torch = torch.tensor(in2_dgrad_kernel).to("cuda")
         weights_dgrad_torch = torch.tensor(weights_dgrad).to("cuda")
         out_torch = self.forward(in1_torch, in2_torch, weights_torch)
 
@@ -36,12 +56,22 @@ def double_backward_cpu(
             grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch],
         )
 
-        return (
-            a.detach().cpu().numpy(),
-            b.detach().cpu().numpy(),
-            c.detach().cpu().numpy(),
-            d.detach().cpu().numpy(),
+        a_np = a.detach().cpu().numpy()
+        b_np = b.detach().cpu().numpy()
+        c_np = c.detach().cpu().numpy()
+        d_np = d.detach().cpu().numpy()
+
+        a_np = IrrepLayoutUtils.transpose_irrep_layout(
+            a_np, self.config.irreps_in1, "mul_ir", layout
         )
+        b_np = IrrepLayoutUtils.transpose_irrep_layout(
+            b_np, self.config.irreps_in2, "mul_ir", layout
+        )
+        d_np = IrrepLayoutUtils.transpose_irrep_layout(
+            d_np, self.config.irreps_out, "mul_ir", layout
+        )
+
+        return (a_np, b_np, c_np, d_np)
 
 
 class NumpyDoubleBackwardMixinConv:
@@ -54,12 +84,30 @@ def double_backward_cpu(
     ):
         assert self.torch_op
 
-        in1_torch = torch.tensor(in1).to("cuda").requires_grad_(True)
-        in2_torch = torch.tensor(in2).to("cuda").requires_grad_(True)
+        layout = self.config.layout
+
+        in1_kernel = IrrepLayoutUtils.transpose_irrep_layout(
+            in1, self.config.irreps_in1, layout, "mul_ir"
+        )
+        in2_kernel = IrrepLayoutUtils.transpose_irrep_layout(
+            in2, self.config.irreps_in2, layout, "mul_ir"
+        )
+        out_grad_kernel = IrrepLayoutUtils.transpose_irrep_layout(
+            out_grad, self.config.irreps_out, layout, "mul_ir"
+        )
+        in1_dgrad_kernel = IrrepLayoutUtils.transpose_irrep_layout(
+            in1_dgrad, self.config.irreps_in1, layout, "mul_ir"
+        )
+        in2_dgrad_kernel = IrrepLayoutUtils.transpose_irrep_layout(
+            in2_dgrad, self.config.irreps_in2, layout, "mul_ir"
+        )
+
+        in1_torch = torch.tensor(in1_kernel).to("cuda").requires_grad_(True)
+        in2_torch = torch.tensor(in2_kernel).to("cuda").requires_grad_(True)
         weights_torch = torch.tensor(weights).to("cuda").requires_grad_(True)
-        out_grad_torch = torch.tensor(out_grad).to("cuda").requires_grad_(True)
-        in1_dgrad_torch = torch.tensor(in1_dgrad).to("cuda")
-        in2_dgrad_torch = torch.tensor(in2_dgrad).to("cuda")
+        out_grad_torch = torch.tensor(out_grad_kernel).to("cuda").requires_grad_(True)
+        in1_dgrad_torch = torch.tensor(in1_dgrad_kernel).to("cuda")
+        in2_dgrad_torch = torch.tensor(in2_dgrad_kernel).to("cuda")
         weights_dgrad_torch = torch.tensor(weights_dgrad).to("cuda")
 
         torch_rows = torch.tensor(graph.rows, device="cuda")
@@ -89,9 +137,19 @@ def double_backward_cpu(
             grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch],
         )
 
-        return (
-            a.detach().cpu().numpy(),
-            b.detach().cpu().numpy(),
-            c.detach().cpu().numpy(),
-            d.detach().cpu().numpy(),
+        a_np = a.detach().cpu().numpy()
+        b_np = b.detach().cpu().numpy()
+        c_np = c.detach().cpu().numpy()
+        d_np = d.detach().cpu().numpy()
+
+        a_np = IrrepLayoutUtils.transpose_irrep_layout(
+            a_np, self.config.irreps_in1, "mul_ir", layout
+        )
+        b_np = IrrepLayoutUtils.transpose_irrep_layout(
+            b_np, self.config.irreps_in2, "mul_ir", layout
         )
+        d_np = IrrepLayoutUtils.transpose_irrep_layout(
+            d_np, self.config.irreps_out, "mul_ir", layout
+        )
+
+        return (a_np, b_np, c_np, d_np)
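Note on the surrounding code: the four tensors returned here are presumably the double-backward gradients with respect to in1, in2, weights, and out_grad, which would explain why a, b, and d are transposed against irreps_in1, irreps_in2, and irreps_out while c (the weights gradient, not irrep-typed) is left alone. A minimal, self-contained sketch of the torch.autograd.grad pattern this mixin wraps, with a toy bilinear map standing in for the real tensor-product kernel (illustrative names only, not the library's API):

import torch

# Toy stand-in for the tensor product: f(x, y, w) = w * x * y.
x = torch.randn(5, requires_grad=True)
y = torch.randn(5, requires_grad=True)
w = torch.randn(5, requires_grad=True)
out = w * x * y

out_grad = torch.randn(5, requires_grad=True)

# First backward pass: keep the graph so the gradients themselves
# can be differentiated again.
gx, gy, gw = torch.autograd.grad(
    out, (x, y, w), grad_outputs=out_grad, create_graph=True
)

# Double backward: contract the first gradients against the incoming
# seeds, mirroring grad_outputs=[in1_dgrad_torch, in2_dgrad_torch,
# weights_dgrad_torch] in the diff above.
dgx, dgy, dgw = torch.randn(5), torch.randn(5), torch.randn(5)
a, b, c, d = torch.autograd.grad(
    (gx, gy, gw), (x, y, w, out_grad), grad_outputs=(dgx, dgy, dgw)
)

The commit's change is orthogonal to this pattern: every irrep-typed array is transposed to "mul_ir" before entering it and transposed back afterward, so the kernel's expected layout is decoupled from self.config.layout.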

openequivariance/openequivariance/_torch/TensorProduct.py

Lines changed: 51 additions & 16 deletions
@@ -1,17 +1,21 @@
-from openequivariance.core.LoopUnrollTP import LoopUnrollTP
+import numpy as np
+import torch
+
 from openequivariance import TPProblem
 from openequivariance._torch import extlib
-import torch
-from openequivariance.core.utils import torch_to_oeq_dtype, dtype_to_enum
-from openequivariance.benchmark.logging_utils import getLogger
+from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin
 from openequivariance._torch.utils import (
+    enum_to_torch_dtype,
     reorder_torch,
     string_to_tensor,
-    enum_to_torch_dtype,
 )
-from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin
-
-import numpy as np
+from openequivariance.benchmark.logging_utils import getLogger
+from openequivariance.core.LoopUnrollTP import LoopUnrollTP
+from openequivariance.core.utils import (
+    IrrepLayoutUtils,
+    dtype_to_enum,
+    torch_to_oeq_dtype,
+)
 
 logger = getLogger()
 
@@ -146,12 +150,24 @@ def forward_cpu(
             weights, not self.config.shared_weights
         )
 
-        torch_L1_in = torch.tensor(L1_in, device="cuda")
-        torch_L2_in = torch.tensor(L2_in, device="cuda")
+        layout = self.config.layout
+
+        L1_in_kernel = IrrepLayoutUtils.transpose_irrep_layout(
+            L1_in, self.config.irreps_in1, layout, "mul_ir"
+        )
+        L2_in_kernel = IrrepLayoutUtils.transpose_irrep_layout(
+            L2_in, self.config.irreps_in2, layout, "mul_ir"
+        )
+
+        torch_L1_in = torch.tensor(L1_in_kernel, device="cuda")
+        torch_L2_in = torch.tensor(L2_in_kernel, device="cuda")
         torch_weights = torch.tensor(weights_chunked, device="cuda")
         torch_L3_out = self.forward(torch_L1_in, torch_L2_in, torch_weights)
 
-        L3_out[:] = torch_L3_out.numpy(force=True)
+        L3_kernel = torch_L3_out.numpy(force=True)
+        L3_out[:] = IrrepLayoutUtils.transpose_irrep_layout(
+            L3_kernel, self.config.irreps_out, "mul_ir", layout
+        )
 
     def backward_cpu(
         self, L1_in, L1_grad, L2_in, L2_grad, L3_grad, weights, weights_grad
@@ -160,18 +176,37 @@ def backward_cpu(
             weights, not self.config.shared_weights
         )
 
-        torch_L1_in = torch.tensor(L1_in, requires_grad=True, device="cuda")
-        torch_L2_in = torch.tensor(L2_in, requires_grad=True, device="cuda")
+        layout = self.config.layout
+
+        L1_in_kernel = IrrepLayoutUtils.transpose_irrep_layout(
+            L1_in, self.config.irreps_in1, layout, "mul_ir"
+        )
+        L2_in_kernel = IrrepLayoutUtils.transpose_irrep_layout(
+            L2_in, self.config.irreps_in2, layout, "mul_ir"
+        )
+        L3_grad_kernel = IrrepLayoutUtils.transpose_irrep_layout(
+            L3_grad, self.config.irreps_out, layout, "mul_ir"
+        )
+
+        torch_L1_in = torch.tensor(L1_in_kernel, requires_grad=True, device="cuda")
+        torch_L2_in = torch.tensor(L2_in_kernel, requires_grad=True, device="cuda")
         torch_weights = torch.tensor(weights_chunked, requires_grad=True, device="cuda")
 
        torch_out = self.forward(torch_L1_in, torch_L2_in, torch_weights)
 
-        torch_L3_grad_in = torch.tensor(L3_grad, device="cuda")
+        torch_L3_grad_in = torch.tensor(L3_grad_kernel, device="cuda")
 
         torch_out.backward(gradient=torch_L3_grad_in)
 
-        L1_grad[:] = torch_L1_in.grad.numpy(force=True)
-        L2_grad[:] = torch_L2_in.grad.numpy(force=True)
+        L1_grad_kernel = torch_L1_in.grad.numpy(force=True)
+        L2_grad_kernel = torch_L2_in.grad.numpy(force=True)
+
+        L1_grad[:] = IrrepLayoutUtils.transpose_irrep_layout(
+            L1_grad_kernel, self.config.irreps_in1, "mul_ir", layout
+        )
+        L2_grad[:] = IrrepLayoutUtils.transpose_irrep_layout(
+            L2_grad_kernel, self.config.irreps_in2, "mul_ir", layout
+        )
        weights_grad[:] = torch_weights.grad.numpy(force=True)
 
         weights_grad[:] = self.reorder_weights_to_e3nn(
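Both this file and the mixin above lean on IrrepLayoutUtils.transpose_irrep_layout(x, irreps, src, dst) to convert between the user-facing layout and the "mul_ir" layout the kernels expect. A rough sketch of what such a helper could look like, under the assumption that each irrep block is a contiguous (mul, 2l + 1) segment stored multiplicity-major in mul_ir order and component-major in ir_mul order (a guess at the semantics, not the library implementation):

import numpy as np

def transpose_irrep_layout_sketch(x, irreps, src, dst):
    """x: (..., dim) array; irreps: list of (mul, ir_dim) pairs."""
    if src == dst:
        return x
    blocks, start = [], 0
    for mul, ir_dim in irreps:
        # Each block is (mul, ir_dim) in mul_ir order, or its
        # transpose in ir_mul order; converting is a per-block swap.
        rows, cols = (mul, ir_dim) if src == "mul_ir" else (ir_dim, mul)
        block = x[..., start : start + mul * ir_dim]
        block = block.reshape(*x.shape[:-1], rows, cols)
        blocks.append(np.swapaxes(block, -1, -2).reshape(*x.shape[:-1], -1))
        start += mul * ir_dim
    return np.concatenate(blocks, axis=-1)

A quick sanity check of the round-trip the diff relies on (mul_ir -> ir_mul -> mul_ir must be the identity, or the gradients written back into L1_grad / L2_grad would land in the wrong slots):

rng = np.random.default_rng(0)
irreps = [(4, 1), (2, 3), (1, 5)]  # e.g. "4x0e + 2x1o + 1x2e"
x = rng.standard_normal((8, sum(m * d for m, d in irreps)))
y = transpose_irrep_layout_sketch(x, irreps, "mul_ir", "ir_mul")
assert np.array_equal(
    transpose_irrep_layout_sketch(y, irreps, "ir_mul", "mul_ir"), x
)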

openequivariance/openequivariance/core/ComputationSchedule.py

Lines changed: 26 additions & 19 deletions
@@ -1,7 +1,9 @@
-import numpy as np
-from openequivariance.core.e3nn_lite import Irreps, TPProblem, wigner_3j
 from itertools import accumulate
+
+import numpy as np
+
 from openequivariance.benchmark.logging_utils import getLogger
+from openequivariance.core.e3nn_lite import Irreps, TPProblem, wigner_3j
 
 logger = getLogger()
 
@@ -27,10 +29,13 @@ def __init__(self, src_irreps, src_views, idxs):
         src_ranges = [src_irreps.slices()[idx] for idx in self.src_dst_map]
         dst_ranges = [self.dst_irreps.slices()[i] for i in self.src_dst_map.values()]
 
+        if src_views[0].layout == "ir_mul":
+            return
+
+        # Merge adjacent src and dst ranges
         self.original_src_ranges = src_ranges
         self.original_dst_ranges = dst_ranges
 
-        # Merge adjacent src and dst ranges
         self.src_ranges = []
         self.dst_ranges = []
 
@@ -195,9 +200,10 @@ def __init__(self, instruction_tup, parent_idx):
             self.instruction_tup, self.parent_idx = instruction_tup, parent_idx
 
     class ChildView:
-        layout: str
-        ir_mul_offset: int
-        ir_mul_stride: int
+        def __init__(self, layout: str, ir_mul_offset: int, ir_mul_stride: int):
+            self.layout = layout
+            self.ir_mul_offset = ir_mul_offset
+            self.ir_mul_stride = ir_mul_stride
 
     def __init__(self, input, mult_threshold):
         self.input = input
@@ -207,7 +213,7 @@ def __init__(self, input, mult_threshold):
         child_reps = [[], [], []]
 
         self.irrep_maps = {}  # Maps a (input_rep_idx #, mul_ir_idx) to a lst[ir_idx]
-        self.irrep_views = [[], [], []] # View
+        self.irrep_views = [[], [], []]  # View
 
         for input_rep_idx, input_rep in enumerate(input_reps):  # Loop over L1, L2, L3
             for mul_ir_idx, mul_ir in enumerate(
@@ -223,19 +229,20 @@ def __init__(self, input, mult_threshold):
                         len(child_reps[input_rep_idx]) - 1
                     )
                     if input.layout == "mul_ir":
-                        self.irrep_views.append(
+                        self.irrep_views[input_rep_idx].append(
                             self.ChildView(
-                                layout="mul_ir",
-                                ir_mul_offset=-1,
-                                ir_mul_stride=-1
-                            ))
+                                layout="mul_ir", ir_mul_offset=-1, ir_mul_stride=-1
+                            )
+                        )
                     elif input.layout == "ir_mul":
-                        self.irrep_views.append(
+                        self.irrep_views[input_rep_idx].append(
                             self.ChildView(
                                 layout="ir_mul",
-                                ir_mul_offset=input_rep.slices()[mul_ir_idx].start + mul_start,
-                                ir_mul_stride=mul_ir.mul
-                            ))
+                                ir_mul_offset=input_rep.slices()[mul_ir_idx].start
+                                + mul_start,
+                                ir_mul_stride=mul_ir.mul,
+                            )
+                        )
 
         new_instructions = []
 
@@ -564,9 +571,9 @@ def calculate_backward_smem(
         for i in range(len(self.segments)):
            L1_idxs, L2_idxs, L3_idxs, inst_idxs = self.segments[i]
 
-            L1Map = IrrepMapping(self.L1, L1_idxs)
-            L2Map = IrrepMapping(self.L2, L2_idxs)
-            L3Map = IrrepMapping(self.L3, L3_idxs)
+            L1Map = IrrepMapping(self.L1, self.problem_splitter.irrep_views[0], L1_idxs)
+            L2Map = IrrepMapping(self.L2, self.problem_splitter.irrep_views[1], L2_idxs)
+            L3Map = IrrepMapping(self.L3, self.problem_splitter.irrep_views[2], L3_idxs)
 
         instructions = [
             (