Commit 85e988f

committed: "Made more progress."
1 parent c78d48f · commit 85e988f

3 files changed: 123 additions & 132 deletions

Lines changed: 10 additions & 68 deletions
```diff
@@ -1,7 +1,5 @@
 import torch
 
-from openequivariance.core.utils import IrrepLayoutUtils
-
 
 class NumpyDoubleBackwardMixin:
     """
@@ -15,30 +13,12 @@ def double_backward_cpu(
     ):
         assert self.torch_op
 
-        layout = self.config.layout
-
-        in1_kernel = IrrepLayoutUtils.transpose_irrep_layout(
-            in1, self.config.irreps_in1, layout, "mul_ir"
-        )
-        in2_kernel = IrrepLayoutUtils.transpose_irrep_layout(
-            in2, self.config.irreps_in2, layout, "mul_ir"
-        )
-        out_grad_kernel = IrrepLayoutUtils.transpose_irrep_layout(
-            out_grad, self.config.irreps_out, layout, "mul_ir"
-        )
-        in1_dgrad_kernel = IrrepLayoutUtils.transpose_irrep_layout(
-            in1_dgrad, self.config.irreps_in1, layout, "mul_ir"
-        )
-        in2_dgrad_kernel = IrrepLayoutUtils.transpose_irrep_layout(
-            in2_dgrad, self.config.irreps_in2, layout, "mul_ir"
-        )
-
-        in1_torch = torch.tensor(in1_kernel).to("cuda").requires_grad_(True)
-        in2_torch = torch.tensor(in2_kernel).to("cuda").requires_grad_(True)
+        in1_torch = torch.tensor(in1).to("cuda").requires_grad_(True)
+        in2_torch = torch.tensor(in2).to("cuda").requires_grad_(True)
         weights_torch = torch.tensor(weights).to("cuda").requires_grad_(True)
-        out_grad_torch = torch.tensor(out_grad_kernel).to("cuda").requires_grad_(True)
-        in1_dgrad_torch = torch.tensor(in1_dgrad_kernel).to("cuda")
-        in2_dgrad_torch = torch.tensor(in2_dgrad_kernel).to("cuda")
+        out_grad_torch = torch.tensor(out_grad).to("cuda").requires_grad_(True)
+        in1_dgrad_torch = torch.tensor(in1_dgrad).to("cuda")
+        in2_dgrad_torch = torch.tensor(in2_dgrad).to("cuda")
         weights_dgrad_torch = torch.tensor(weights_dgrad).to("cuda")
         out_torch = self.forward(in1_torch, in2_torch, weights_torch)
 
@@ -61,16 +41,6 @@ def double_backward_cpu(
         c_np = c.detach().cpu().numpy()
         d_np = d.detach().cpu().numpy()
 
-        a_np = IrrepLayoutUtils.transpose_irrep_layout(
-            a_np, self.config.irreps_in1, "mul_ir", layout
-        )
-        b_np = IrrepLayoutUtils.transpose_irrep_layout(
-            b_np, self.config.irreps_in2, "mul_ir", layout
-        )
-        d_np = IrrepLayoutUtils.transpose_irrep_layout(
-            d_np, self.config.irreps_out, "mul_ir", layout
-        )
-
         return (a_np, b_np, c_np, d_np)
 
 
@@ -84,30 +54,12 @@ def double_backward_cpu(
     ):
         assert self.torch_op
 
-        layout = self.config.layout
-
-        in1_kernel = IrrepLayoutUtils.transpose_irrep_layout(
-            in1, self.config.irreps_in1, layout, "mul_ir"
-        )
-        in2_kernel = IrrepLayoutUtils.transpose_irrep_layout(
-            in2, self.config.irreps_in2, layout, "mul_ir"
-        )
-        out_grad_kernel = IrrepLayoutUtils.transpose_irrep_layout(
-            out_grad, self.config.irreps_out, layout, "mul_ir"
-        )
-        in1_dgrad_kernel = IrrepLayoutUtils.transpose_irrep_layout(
-            in1_dgrad, self.config.irreps_in1, layout, "mul_ir"
-        )
-        in2_dgrad_kernel = IrrepLayoutUtils.transpose_irrep_layout(
-            in2_dgrad, self.config.irreps_in2, layout, "mul_ir"
-        )
-
-        in1_torch = torch.tensor(in1_kernel).to("cuda").requires_grad_(True)
-        in2_torch = torch.tensor(in2_kernel).to("cuda").requires_grad_(True)
+        in1_torch = torch.tensor(in1).to("cuda").requires_grad_(True)
+        in2_torch = torch.tensor(in2).to("cuda").requires_grad_(True)
         weights_torch = torch.tensor(weights).to("cuda").requires_grad_(True)
-        out_grad_torch = torch.tensor(out_grad_kernel).to("cuda").requires_grad_(True)
-        in1_dgrad_torch = torch.tensor(in1_dgrad_kernel).to("cuda")
-        in2_dgrad_torch = torch.tensor(in2_dgrad_kernel).to("cuda")
+        out_grad_torch = torch.tensor(out_grad).to("cuda").requires_grad_(True)
+        in1_dgrad_torch = torch.tensor(in1_dgrad).to("cuda")
+        in2_dgrad_torch = torch.tensor(in2_dgrad).to("cuda")
         weights_dgrad_torch = torch.tensor(weights_dgrad).to("cuda")
 
         torch_rows = torch.tensor(graph.rows, device="cuda")
@@ -142,14 +94,4 @@ def double_backward_cpu(
         c_np = c.detach().cpu().numpy()
         d_np = d.detach().cpu().numpy()
 
-        a_np = IrrepLayoutUtils.transpose_irrep_layout(
-            a_np, self.config.irreps_in1, "mul_ir", layout
-        )
-        b_np = IrrepLayoutUtils.transpose_irrep_layout(
-            b_np, self.config.irreps_in2, "mul_ir", layout
-        )
-        d_np = IrrepLayoutUtils.transpose_irrep_layout(
-            d_np, self.config.irreps_out, "mul_ir", layout
-        )
-
         return (a_np, b_np, c_np, d_np)
```
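For context, `double_backward_cpu` builds its reference answer by differentiating the backward pass itself with `torch.autograd`. A minimal, self-contained sketch of that technique follows; `f`, `double_backward`, and all shapes here are illustrative stand-ins, not this repo's API:

```python
import torch

def double_backward(f, x, w, out_grad, x_dgrad, w_dgrad):
    # Differentiate the backward pass: seed a first backward with out_grad,
    # then backpropagate through the resulting gradients themselves.
    x = x.clone().requires_grad_(True)
    w = w.clone().requires_grad_(True)
    out_grad = out_grad.clone().requires_grad_(True)

    out = f(x, w)
    gx, gw = torch.autograd.grad(
        out, (x, w), grad_outputs=out_grad, create_graph=True
    )

    # Contract the first-order gradients with the perturbation seeds, then
    # differentiate the scalar again with respect to every leaf.
    scalar = (gx * x_dgrad).sum() + (gw * w_dgrad).sum()
    return torch.autograd.grad(scalar, (x, w, out_grad))

# Example on a simple bilinear-style map:
f = lambda x, w: (x * w) ** 2
args = [torch.randn(5) for _ in range(5)]
x_grad2, w_grad2, out_dgrad = double_backward(f, *args)
```

The mixin runs the same dance on CUDA tensors built from its NumPy inputs, then ships the results back to NumPy; after this commit it does so without any layout transposition.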

openequivariance/openequivariance/_torch/TensorProduct.py

Lines changed: 9 additions & 44 deletions
```diff
@@ -11,11 +11,7 @@
 )
 from openequivariance.benchmark.logging_utils import getLogger
 from openequivariance.core.LoopUnrollTP import LoopUnrollTP
-from openequivariance.core.utils import (
-    IrrepLayoutUtils,
-    dtype_to_enum,
-    torch_to_oeq_dtype,
-)
+from openequivariance.core.utils import dtype_to_enum, torch_to_oeq_dtype
 
 logger = getLogger()
 
@@ -150,24 +146,12 @@ def forward_cpu(
             weights, not self.config.shared_weights
         )
 
-        layout = self.config.layout
-
-        L1_in_kernel = IrrepLayoutUtils.transpose_irrep_layout(
-            L1_in, self.config.irreps_in1, layout, "mul_ir"
-        )
-        L2_in_kernel = IrrepLayoutUtils.transpose_irrep_layout(
-            L2_in, self.config.irreps_in2, layout, "mul_ir"
-        )
-
-        torch_L1_in = torch.tensor(L1_in_kernel, device="cuda")
-        torch_L2_in = torch.tensor(L2_in_kernel, device="cuda")
+        torch_L1_in = torch.tensor(L1_in, device="cuda")
+        torch_L2_in = torch.tensor(L2_in, device="cuda")
         torch_weights = torch.tensor(weights_chunked, device="cuda")
         torch_L3_out = self.forward(torch_L1_in, torch_L2_in, torch_weights)
 
-        L3_kernel = torch_L3_out.numpy(force=True)
-        L3_out[:] = IrrepLayoutUtils.transpose_irrep_layout(
-            L3_kernel, self.config.irreps_out, "mul_ir", layout
-        )
+        L3_out[:] = torch_L3_out.numpy(force=True)
 
     def backward_cpu(
         self, L1_in, L1_grad, L2_in, L2_grad, L3_grad, weights, weights_grad
@@ -176,37 +160,18 @@ def backward_cpu(
             weights, not self.config.shared_weights
         )
 
-        layout = self.config.layout
-
-        L1_in_kernel = IrrepLayoutUtils.transpose_irrep_layout(
-            L1_in, self.config.irreps_in1, layout, "mul_ir"
-        )
-        L2_in_kernel = IrrepLayoutUtils.transpose_irrep_layout(
-            L2_in, self.config.irreps_in2, layout, "mul_ir"
-        )
-        L3_grad_kernel = IrrepLayoutUtils.transpose_irrep_layout(
-            L3_grad, self.config.irreps_out, layout, "mul_ir"
-        )
-
-        torch_L1_in = torch.tensor(L1_in_kernel, requires_grad=True, device="cuda")
-        torch_L2_in = torch.tensor(L2_in_kernel, requires_grad=True, device="cuda")
+        torch_L1_in = torch.tensor(L1_in, requires_grad=True, device="cuda")
+        torch_L2_in = torch.tensor(L2_in, requires_grad=True, device="cuda")
         torch_weights = torch.tensor(weights_chunked, requires_grad=True, device="cuda")
 
        torch_out = self.forward(torch_L1_in, torch_L2_in, torch_weights)
 
-        torch_L3_grad_in = torch.tensor(L3_grad_kernel, device="cuda")
+        torch_L3_grad_in = torch.tensor(L3_grad, device="cuda")
 
         torch_out.backward(gradient=torch_L3_grad_in)
 
-        L1_grad_kernel = torch_L1_in.grad.numpy(force=True)
-        L2_grad_kernel = torch_L2_in.grad.numpy(force=True)
-
-        L1_grad[:] = IrrepLayoutUtils.transpose_irrep_layout(
-            L1_grad_kernel, self.config.irreps_in1, "mul_ir", layout
-        )
-        L2_grad[:] = IrrepLayoutUtils.transpose_irrep_layout(
-            L2_grad_kernel, self.config.irreps_in2, "mul_ir", layout
-        )
+        L1_grad[:] = torch_L1_in.grad.numpy(force=True)
+        L2_grad[:] = torch_L2_in.grad.numpy(force=True)
         weights_grad[:] = torch_weights.grad.numpy(force=True)
 
         weights_grad[:] = self.reorder_weights_to_e3nn(
```
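With the transposes gone, `forward_cpu` and `backward_cpu` reduce to a plain NumPy-to-CUDA bridge: host buffers in, one device round trip, results written back into the caller's preallocated arrays. A stripped-down sketch of the pattern (names illustrative; `kernel` stands in for `self.forward`):

```python
import numpy as np
import torch

def forward_cpu_bridge(kernel, L1_in: np.ndarray, L2_in: np.ndarray,
                       L3_out: np.ndarray, weights: np.ndarray) -> None:
    # Host to device: torch.tensor copies the NumPy buffers onto the GPU.
    t1 = torch.tensor(L1_in, device="cuda")
    t2 = torch.tensor(L2_in, device="cuda")
    tw = torch.tensor(weights, device="cuda")

    t3 = kernel(t1, t2, tw)

    # Device to host: numpy(force=True) detaches and copies in one call;
    # writing through [:] fills the caller's buffer in place.
    L3_out[:] = t3.numpy(force=True)
```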

openequivariance/openequivariance/benchmark/correctness_utils.py

Lines changed: 104 additions & 20 deletions
```diff
@@ -1,17 +1,18 @@
 from typing import Optional, Union
 
-from openequivariance.core.TensorProductBase import TensorProductBase
-from openequivariance.core.e3nn_lite import TPProblem
+import numpy as np
+import numpy.linalg as la
+
 from openequivariance._torch.CUETensorProduct import CUETensorProduct
+from openequivariance.benchmark.logging_utils import bcolors, getLogger
 from openequivariance.benchmark.random_buffer_utils import (
-    get_random_buffers_forward,
     get_random_buffers_backward,
     get_random_buffers_double_backward,
+    get_random_buffers_forward,
 )
-
-from openequivariance.benchmark.logging_utils import getLogger, bcolors
-import numpy as np
-import numpy.linalg as la
+from openequivariance.core.e3nn_lite import TPProblem
+from openequivariance.core.TensorProductBase import TensorProductBase
+from openequivariance.core.utils import IrrepLayoutUtils
 
 logger = getLogger()
```
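The primitive this commit shuffles into the test harness is the transpose between the `mul_ir` layout (multiplicity-major within each irrep segment, e3nn's convention) and `ir_mul` (component-major). A hedged NumPy sketch of the per-segment operation; the real `IrrepLayoutUtils.transpose_irrep_layout(array, irreps, from_layout, to_layout)` presumably walks the segments of an `Irreps` object in this spirit:

```python
import numpy as np

def transpose_layout(x, segments, src="mul_ir"):
    # segments: list of (mul, ir_dim) pairs, e.g. "2x1e" -> [(2, 3)].
    # Each segment occupies mul * ir_dim slots on the last axis; swapping
    # the two axes converts mul_ir <-> ir_mul (the op is an involution).
    out, offset = [], 0
    for mul, ir_dim in segments:
        width = mul * ir_dim
        block = x[..., offset : offset + width]
        shape = (mul, ir_dim) if src == "mul_ir" else (ir_dim, mul)
        block = block.reshape(*x.shape[:-1], *shape)
        out.append(np.swapaxes(block, -1, -2).reshape(*x.shape[:-1], width))
        offset += width
    return np.concatenate(out, axis=-1)

# One l=1 segment with multiplicity 2: [m0c0 m0c1 m0c2 m1c0 m1c1 m1c2]
x = np.arange(6.0)[None, :]
print(transpose_layout(x, [(2, 3)]))  # -> [[0. 3. 1. 4. 2. 5.]]
```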

```diff
@@ -81,7 +82,7 @@ def correctness_forward(
 
     in1, in2, weights, out = get_random_buffers_forward(problem, batch_size, prng_seed)
 
-    # run reference
+    # run reference (always in mul_ir)
     ref_tp = reference_implementation(problem)
 
     ref_out = out.copy()
@@ -93,13 +94,31 @@ def correctness_forward(
     if problem.shared_weights and test_implementation == CUETensorProduct:
         weights_copy = weights[np.newaxis, :]
 
-    # run test
+    # run test (may require ir_mul conversion)
     test_tp = instantiate_implementation(test_implementation, problem)
+    test_layout = getattr(test_tp.config, "layout", "mul_ir")
+
+    test_in1 = in1.copy()
+    test_in2 = in2.copy()
     test_out = out.copy()
+
+    if test_layout == "ir_mul":
+        test_in1 = IrrepLayoutUtils.transpose_irrep_layout(
+            test_in1, problem.irreps_in1, "mul_ir", "ir_mul"
+        )
+        test_in2 = IrrepLayoutUtils.transpose_irrep_layout(
+            test_in2, problem.irreps_in2, "mul_ir", "ir_mul"
+        )
+
     test_tp.forward_cpu(
-        L1_in=in1.copy(), L2_in=in2.copy(), L3_out=test_out, weights=weights_copy
+        L1_in=test_in1, L2_in=test_in2, L3_out=test_out, weights=weights_copy
     )
 
+    if test_layout == "ir_mul":
+        test_out = IrrepLayoutUtils.transpose_irrep_layout(
+            test_out, problem.irreps_out, "ir_mul", "mul_ir"
+        )
+
     for name, to_check, ground_truth in [("output", ref_out, test_out)]:
         result[name] = check_similiarity(
             name, to_check, ground_truth, correctness_threshold
```
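The same convert / run / convert-back dance now appears in all three correctness paths, so a pair of tiny helpers could collapse the repeated `if test_layout == "ir_mul":` blocks. A hypothetical shape, not part of this commit:

```python
from openequivariance.core.utils import IrrepLayoutUtils

def to_test_layout(arr, irreps, layout):
    # No-op for mul_ir implementations; transpose once for ir_mul ones.
    if layout == "ir_mul":
        return IrrepLayoutUtils.transpose_irrep_layout(arr, irreps, "mul_ir", "ir_mul")
    return arr

def from_test_layout(arr, irreps, layout):
    # Inverse transpose, bringing test outputs back to mul_ir for comparison.
    if layout == "ir_mul":
        return IrrepLayoutUtils.transpose_irrep_layout(arr, irreps, "ir_mul", "mul_ir")
    return arr
```

With these, the forward path would shrink to `test_in1 = to_test_layout(in1.copy(), problem.irreps_in1, test_layout)` and a single `from_test_layout` on `test_out`.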
```diff
@@ -144,7 +163,7 @@ def correctness_backward(
         weights_grad=ref_weights_grad,
     )
 
-    # run test version
+    # run test version (may require ir_mul conversion)
    test_weights_grad = weights_grad.copy()
     test_in1_grad = in1_grad.copy()
     test_in2_grad = in2_grad.copy()
@@ -156,16 +175,41 @@ def correctness_backward(
         test_weights_grad = test_weights_grad[np.newaxis, :]
 
     test_tp = instantiate_implementation(test_implementation, problem)
+    test_layout = getattr(test_tp.config, "layout", "mul_ir")
+
+    test_in1 = in1.copy()
+    test_in2 = in2.copy()
+    test_L3_grad = out_grad.copy()
+
+    if test_layout == "ir_mul":
+        test_in1 = IrrepLayoutUtils.transpose_irrep_layout(
+            test_in1, problem.irreps_in1, "mul_ir", "ir_mul"
+        )
+        test_in2 = IrrepLayoutUtils.transpose_irrep_layout(
+            test_in2, problem.irreps_in2, "mul_ir", "ir_mul"
+        )
+        test_L3_grad = IrrepLayoutUtils.transpose_irrep_layout(
+            test_L3_grad, problem.irreps_out, "mul_ir", "ir_mul"
+        )
+
     test_tp.backward_cpu(
-        L1_in=in1.copy(),
+        L1_in=test_in1,
         L1_grad=test_in1_grad,
-        L2_in=in2.copy(),
+        L2_in=test_in2,
         L2_grad=test_in2_grad,
-        L3_grad=out_grad.copy(),
+        L3_grad=test_L3_grad,
         weights=weights_copy,
         weights_grad=test_weights_grad,
     )
 
+    if test_layout == "ir_mul":
+        test_in1_grad = IrrepLayoutUtils.transpose_irrep_layout(
+            test_in1_grad, problem.irreps_in1, "ir_mul", "mul_ir"
+        )
+        test_in2_grad = IrrepLayoutUtils.transpose_irrep_layout(
+            test_in2_grad, problem.irreps_in2, "ir_mul", "mul_ir"
+        )
+
     weight_threshold = (
         correctness_threshold * batch_size
         if problem.shared_weights
```
```diff
@@ -210,7 +254,9 @@ def correctness_double_backward(
     result = {"thresh": correctness_threshold, "batch_size": batch_size}
 
     tensors = []
-    for _, impl in enumerate([test_implementation, reference_implementation]):
+    for is_test_impl, impl in enumerate(
+        [test_implementation, reference_implementation]
+    ):
         tp = instantiate_implementation(impl, problem)
         weights_reordered = tp.reorder_weights_from_e3nn(
             weights, has_batch_dim=not problem.shared_weights
```
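One readability note on the loop above: `enumerate` yields index 0 for the test implementation, so the guard used later is `is_test_impl == 0`, which reads opposite to the variable's name. Pairing each implementation with an explicit boolean would avoid the inversion; a hypothetical alternative, not what the commit does:

```python
# Hypothetical: make the flag a real boolean so `if is_test_impl:` is true
# for the test implementation instead of relying on its index being 0.
for is_test_impl, impl in (
    (True, test_implementation),
    (False, reference_implementation),
):
    tp = instantiate_implementation(impl, problem)
    ...
```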
```diff
@@ -222,15 +268,53 @@ def correctness_double_backward(
         if impl == CUETensorProduct and problem.shared_weights:
             weights_reordered = weights_reordered[np.newaxis, :]
 
+        tp_layout = getattr(tp.config, "layout", "mul_ir")
+        apply_test_layout = is_test_impl == 0 and tp_layout == "ir_mul"
+
+        db_in1 = in1
+        db_in2 = in2
+        db_out_grad = out_grad
+        db_in1_dgrad = in1_dgrad
+        db_in2_dgrad = in2_dgrad
+
+        if apply_test_layout:
+            db_in1 = IrrepLayoutUtils.transpose_irrep_layout(
+                in1, problem.irreps_in1, "mul_ir", "ir_mul"
+            )
+            db_in2 = IrrepLayoutUtils.transpose_irrep_layout(
+                in2, problem.irreps_in2, "mul_ir", "ir_mul"
+            )
+            db_out_grad = IrrepLayoutUtils.transpose_irrep_layout(
+                out_grad, problem.irreps_out, "mul_ir", "ir_mul"
+            )
+            db_in1_dgrad = IrrepLayoutUtils.transpose_irrep_layout(
+                in1_dgrad, problem.irreps_in1, "mul_ir", "ir_mul"
+            )
+            db_in2_dgrad = IrrepLayoutUtils.transpose_irrep_layout(
+                in2_dgrad, problem.irreps_in2, "mul_ir", "ir_mul"
+            )
+
         in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu(
-            in1,
-            in2,
-            out_grad,
+            db_in1,
+            db_in2,
+            db_out_grad,
             weights_reordered,
             weights_dgrad_reordered,
-            in1_dgrad,
-            in2_dgrad,
+            db_in1_dgrad,
+            db_in2_dgrad,
         )
+
+        if apply_test_layout:
+            out_dgrad = IrrepLayoutUtils.transpose_irrep_layout(
+                out_dgrad, problem.irreps_out, "ir_mul", "mul_ir"
+            )
+            in1_grad = IrrepLayoutUtils.transpose_irrep_layout(
+                in1_grad, problem.irreps_in1, "ir_mul", "mul_ir"
+            )
+            in2_grad = IrrepLayoutUtils.transpose_irrep_layout(
+                in2_grad, problem.irreps_in2, "ir_mul", "mul_ir"
+            )
+
         tensors.append(
             (
                 out_dgrad,
```
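After the loop, `tensors` holds one `(out_dgrad, …)` tuple per implementation, both now expressed in `mul_ir`, and the harness compares them entry-wise. The repo's `check_similiarity` is not shown in this diff; a hedged sketch of a norm-based check in the same spirit:

```python
import numpy as np
import numpy.linalg as la

def check_close(name, to_check, ground_truth, thresh):
    # Relative max-norm error, guarded against an all-zero reference.
    diff = la.norm((to_check - ground_truth).flatten(), ord=np.inf)
    scale = max(la.norm(ground_truth.flatten(), ord=np.inf), 1e-12)
    return {"name": name, "rel_err": diff / scale, "pass": diff / scale < thresh}
```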
