Skip to content

Commit 4947781

Browse files
committed
compaction.
1 parent 44ecf0e commit 4947781

1 file changed

Lines changed: 111 additions & 140 deletions

File tree

openequivariance/openequivariance/benchmark/correctness_utils.py

Lines changed: 111 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -79,47 +79,39 @@ def correctness_forward(
7979
reference_implementation = E3NNTensorProduct
8080

8181
result = {"thresh": correctness_threshold, "batch_size": batch_size}
82-
8382
in1, in2, weights, out = get_random_buffers_forward(problem, batch_size, prng_seed)
83+
outputs = []
8484

85-
# run reference (always in mul_ir)
86-
ref_tp = reference_implementation(problem)
87-
88-
ref_out = out.copy()
89-
ref_tp.forward_cpu(
90-
L1_in=in1.copy(), L2_in=in2.copy(), L3_out=ref_out, weights=weights.copy()
91-
)
92-
93-
weights_copy = weights.copy()
94-
if problem.shared_weights and test_implementation == CUETensorProduct:
95-
weights_copy = weights[np.newaxis, :]
96-
97-
# run test (may require ir_mul conversion)
98-
test_tp = instantiate_implementation(test_implementation, problem)
99-
test_layout = getattr(test_tp.config, "layout", "mul_ir")
100-
101-
test_in1 = in1.copy()
102-
test_in2 = in2.copy()
103-
test_out = out.copy()
104-
105-
if test_layout == "ir_mul":
106-
test_in1 = IrrepLayoutUtils.transpose_irrep_layout(
107-
test_in1, problem.irreps_in1, "mul_ir", "ir_mul"
108-
)
109-
test_in2 = IrrepLayoutUtils.transpose_irrep_layout(
110-
test_in2, problem.irreps_in2, "mul_ir", "ir_mul"
111-
)
112-
113-
test_tp.forward_cpu(
114-
L1_in=test_in1, L2_in=test_in2, L3_out=test_out, weights=weights_copy
115-
)
85+
for i, impl in enumerate([test_implementation, reference_implementation]):
86+
is_test_impl = (i == 0)
87+
tp = instantiate_implementation(impl, problem)
88+
uses_cue = impl == CUETensorProduct or isinstance(tp, CUETensorProduct)
89+
run_in1, run_in2, run_weights, run_out = [ buf.copy() for buf in (in1, in2, weights, out) ]
90+
91+
if problem.shared_weights and uses_cue:
92+
run_weights = run_weights[np.newaxis, :]
93+
94+
# Transpose inputs, if necessary, for the test implementation
95+
if is_test_impl:
96+
run_in1, run_in2 = [
97+
IrrepLayoutUtils.transpose_irrep_layout(
98+
arr, irreps, "mul_ir", tp.config.layout
99+
) for arr, irreps in zip(
100+
(run_in1, run_in2),
101+
(problem.irreps_in1, problem.irreps_in2)
102+
)
103+
]
104+
105+
tp.forward_cpu(L1_in=run_in1, L2_in=run_in2, L3_out=run_out, weights=run_weights)
106+
107+
if is_test_impl:
108+
run_out = IrrepLayoutUtils.transpose_irrep_layout(
109+
run_out, problem.irreps_out, tp.config.layout, "mul_ir"
110+
)
116111

117-
if test_layout == "ir_mul":
118-
test_out = IrrepLayoutUtils.transpose_irrep_layout(
119-
test_out, problem.irreps_out, "ir_mul", "mul_ir"
120-
)
112+
outputs.append(run_out)
121113

122-
for name, to_check, ground_truth in [("output", ref_out, test_out)]:
114+
for name, to_check, ground_truth in [("output", outputs[0], outputs[1])]:
123115
result[name] = check_similiarity(
124116
name, to_check, ground_truth, correctness_threshold
125117
)
@@ -142,87 +134,72 @@ def correctness_backward(
142134

143135
result = {"thresh": correctness_threshold, "batch_size": batch_size}
144136

145-
# run reference
146137
in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad = (
147138
get_random_buffers_backward(problem, batch_size, prng_seed)
148139
)
149140

150-
ref_tp = reference_implementation(problem)
151-
152-
ref_weights_grad = weights_grad.copy()
153-
ref_in1_grad = in1_grad.copy()
154-
ref_in2_grad = in2_grad.copy()
155-
156-
ref_tp.backward_cpu(
157-
L1_in=in1.copy(),
158-
L1_grad=ref_in1_grad,
159-
L2_in=in2.copy(),
160-
L2_grad=ref_in2_grad,
161-
L3_grad=out_grad.copy(),
162-
weights=weights.copy(),
163-
weights_grad=ref_weights_grad,
164-
)
165-
166-
# run test version (may require ir_mul conversion)
167-
test_weights_grad = weights_grad.copy()
168-
test_in1_grad = in1_grad.copy()
169-
test_in2_grad = in2_grad.copy()
170-
171-
weights_copy = weights.copy()
172-
173-
if problem.shared_weights and test_implementation == CUETensorProduct:
174-
weights_copy = weights[np.newaxis, :]
175-
test_weights_grad = test_weights_grad[np.newaxis, :]
176-
177-
test_tp = instantiate_implementation(test_implementation, problem)
178-
test_layout = getattr(test_tp.config, "layout", "mul_ir")
179-
180-
test_in1 = in1.copy()
181-
test_in2 = in2.copy()
182-
test_L3_grad = out_grad.copy()
141+
grads = []
142+
for i, impl in enumerate([test_implementation, reference_implementation]):
143+
is_test_impl = i == 0
144+
tp = instantiate_implementation(impl, problem)
183145

184-
if test_layout == "ir_mul":
185-
test_in1 = IrrepLayoutUtils.transpose_irrep_layout(
186-
test_in1, problem.irreps_in1, "mul_ir", "ir_mul"
187-
)
188-
test_in2 = IrrepLayoutUtils.transpose_irrep_layout(
189-
test_in2, problem.irreps_in2, "mul_ir", "ir_mul"
190-
)
191-
test_L3_grad = IrrepLayoutUtils.transpose_irrep_layout(
192-
test_L3_grad, problem.irreps_out, "mul_ir", "ir_mul"
146+
run_in1, run_in2, run_L3_grad, run_weights, run_weights_grad, run_in1_grad, run_in2_grad = [
147+
buf.copy()
148+
for buf in (in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad)
149+
]
150+
151+
uses_cue = impl == CUETensorProduct or isinstance(tp, CUETensorProduct)
152+
if problem.shared_weights and uses_cue:
153+
run_weights = run_weights[np.newaxis, :]
154+
run_weights_grad = run_weights_grad[np.newaxis, :]
155+
156+
if is_test_impl:
157+
run_in1, run_in2, run_L3_grad = [
158+
IrrepLayoutUtils.transpose_irrep_layout(
159+
arr, irreps, "mul_ir", tp.config.layout
160+
)
161+
for arr, irreps in zip(
162+
(run_in1, run_in2, run_L3_grad),
163+
(problem.irreps_in1, problem.irreps_in2, problem.irreps_out),
164+
)
165+
]
166+
167+
tp.backward_cpu(
168+
L1_in=run_in1,
169+
L1_grad=run_in1_grad,
170+
L2_in=run_in2,
171+
L2_grad=run_in2_grad,
172+
L3_grad=run_L3_grad,
173+
weights=run_weights,
174+
weights_grad=run_weights_grad,
193175
)
194176

195-
test_tp.backward_cpu(
196-
L1_in=test_in1,
197-
L1_grad=test_in1_grad,
198-
L2_in=test_in2,
199-
L2_grad=test_in2_grad,
200-
L3_grad=test_L3_grad,
201-
weights=weights_copy,
202-
weights_grad=test_weights_grad,
203-
)
177+
if is_test_impl:
178+
run_in1_grad, run_in2_grad = [
179+
IrrepLayoutUtils.transpose_irrep_layout(
180+
arr, irreps, tp.config.layout, "mul_ir"
181+
)
182+
for arr, irreps in zip(
183+
(run_in1_grad, run_in2_grad),
184+
(problem.irreps_in1, problem.irreps_in2),
185+
)
186+
]
204187

205-
if test_layout == "ir_mul":
206-
test_in1_grad = IrrepLayoutUtils.transpose_irrep_layout(
207-
test_in1_grad, problem.irreps_in1, "ir_mul", "mul_ir"
208-
)
209-
test_in2_grad = IrrepLayoutUtils.transpose_irrep_layout(
210-
test_in2_grad, problem.irreps_in2, "ir_mul", "mul_ir"
211-
)
188+
if problem.shared_weights:
189+
run_weights_grad = run_weights_grad.squeeze()
190+
191+
grads.append((run_weights_grad, run_in1_grad, run_in2_grad))
212192

213193
weight_threshold = (
214194
correctness_threshold * batch_size
215195
if problem.shared_weights
216196
else correctness_threshold
217197
)
218198

219-
if problem.shared_weights:
220-
test_weights_grad = test_weights_grad.squeeze()
221-
222199
for name, to_check, ground_truth, threshold in [
223-
("weight_grad", test_weights_grad, ref_weights_grad, weight_threshold),
224-
("in1_grad", test_in1_grad, ref_in1_grad, correctness_threshold),
225-
("in2_grad", test_in2_grad, ref_in2_grad, correctness_threshold),
200+
("weight_grad", grads[0][0], grads[1][0], weight_threshold),
201+
("in1_grad", grads[0][1], grads[1][1], correctness_threshold),
202+
("in2_grad", grads[0][2], grads[1][2], correctness_threshold),
226203
]:
227204
result[name] = check_similiarity(name, to_check, ground_truth, threshold)
228205

@@ -254,9 +231,8 @@ def correctness_double_backward(
254231
result = {"thresh": correctness_threshold, "batch_size": batch_size}
255232

256233
tensors = []
257-
for is_test_impl, impl in enumerate(
258-
[test_implementation, reference_implementation]
259-
):
234+
for i, impl in enumerate([test_implementation, reference_implementation]):
235+
is_test_impl = i == 0
260236
tp = instantiate_implementation(impl, problem)
261237
weights_reordered = tp.reorder_weights_from_e3nn(
262238
weights, has_batch_dim=not problem.shared_weights
@@ -268,31 +244,26 @@ def correctness_double_backward(
268244
if impl == CUETensorProduct and problem.shared_weights:
269245
weights_reordered = weights_reordered[np.newaxis, :]
270246

271-
tp_layout = getattr(tp.config, "layout", "mul_ir")
272-
apply_test_layout = is_test_impl == 0 and tp_layout == "ir_mul"
273-
274-
db_in1 = in1
275-
db_in2 = in2
276-
db_out_grad = out_grad
277-
db_in1_dgrad = in1_dgrad
278-
db_in2_dgrad = in2_dgrad
279-
280-
if apply_test_layout:
281-
db_in1 = IrrepLayoutUtils.transpose_irrep_layout(
282-
in1, problem.irreps_in1, "mul_ir", "ir_mul"
283-
)
284-
db_in2 = IrrepLayoutUtils.transpose_irrep_layout(
285-
in2, problem.irreps_in2, "mul_ir", "ir_mul"
286-
)
287-
db_out_grad = IrrepLayoutUtils.transpose_irrep_layout(
288-
out_grad, problem.irreps_out, "mul_ir", "ir_mul"
289-
)
290-
db_in1_dgrad = IrrepLayoutUtils.transpose_irrep_layout(
291-
in1_dgrad, problem.irreps_in1, "mul_ir", "ir_mul"
292-
)
293-
db_in2_dgrad = IrrepLayoutUtils.transpose_irrep_layout(
294-
in2_dgrad, problem.irreps_in2, "mul_ir", "ir_mul"
295-
)
247+
db_in1, db_in2, db_out_grad, db_in1_dgrad, db_in2_dgrad = [
248+
buf.copy() for buf in (in1, in2, out_grad, in1_dgrad, in2_dgrad)
249+
]
250+
251+
if is_test_impl:
252+
db_in1, db_in2, db_out_grad, db_in1_dgrad, db_in2_dgrad = [
253+
IrrepLayoutUtils.transpose_irrep_layout(
254+
arr, irreps, "mul_ir", tp.config.layout
255+
)
256+
for arr, irreps in zip(
257+
(db_in1, db_in2, db_out_grad, db_in1_dgrad, db_in2_dgrad),
258+
(
259+
problem.irreps_in1,
260+
problem.irreps_in2,
261+
problem.irreps_out,
262+
problem.irreps_in1,
263+
problem.irreps_in2,
264+
),
265+
)
266+
]
296267

297268
in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu(
298269
db_in1,
@@ -304,16 +275,16 @@ def correctness_double_backward(
304275
db_in2_dgrad,
305276
)
306277

307-
if apply_test_layout:
308-
out_dgrad = IrrepLayoutUtils.transpose_irrep_layout(
309-
out_dgrad, problem.irreps_out, "ir_mul", "mul_ir"
310-
)
311-
in1_grad = IrrepLayoutUtils.transpose_irrep_layout(
312-
in1_grad, problem.irreps_in1, "ir_mul", "mul_ir"
313-
)
314-
in2_grad = IrrepLayoutUtils.transpose_irrep_layout(
315-
in2_grad, problem.irreps_in2, "ir_mul", "mul_ir"
316-
)
278+
if is_test_impl:
279+
out_dgrad, in1_grad, in2_grad = [
280+
IrrepLayoutUtils.transpose_irrep_layout(
281+
arr, irreps, tp.config.layout, "mul_ir"
282+
)
283+
for arr, irreps in zip(
284+
(out_dgrad, in1_grad, in2_grad),
285+
(problem.irreps_out, problem.irreps_in1, problem.irreps_in2),
286+
)
287+
]
317288

318289
tensors.append(
319290
(

0 commit comments

Comments (0)