Skip to content

Commit 3c9ed29

Browse files
committed
More compaction.
1 parent 7fd7ab6 commit 3c9ed29

2 files changed

Lines changed: 89 additions & 90 deletions

File tree

openequivariance/openequivariance/core/ConvolutionBase.py

Lines changed: 86 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -144,13 +144,6 @@ def test_correctness_forward(
144144
check_reproducible=True,
145145
high_precision_ref=False,
146146
):
147-
def maybe_transpose_input_for_test_impl(x, irreps):
148-
if self.config.layout == "ir_mul":
149-
return IrrepLayoutUtils.transpose_irrep_layout(
150-
x, irreps, "mul_ir", "ir_mul"
151-
)
152-
return x
153-
154147
if reference_implementation is None:
155148
from openequivariance._torch.E3NNConv import E3NNConv
156149

@@ -192,23 +185,29 @@ def maybe_transpose_input_for_test_impl(x, irreps):
192185

193186
ref_out[:] = ref_tp.forward(**args).cpu().numpy()
194187

195-
test_out = out.copy()
188+
run_in1, run_in2, run_weights, test_out = [
189+
buf.copy() for buf in (in1, in2, weights, out)
190+
]
191+
run_in1, run_in2 = [
192+
IrrepLayoutUtils.transpose_irrep_layout(
193+
arr, irreps, "mul_ir", self.config.layout
194+
)
195+
for arr, irreps in zip(
196+
(run_in1, run_in2),
197+
(self.config.irreps_in1, self.config.irreps_in2),
198+
)
199+
]
196200
self.forward_cpu(
197-
L1_in=maybe_transpose_input_for_test_impl(
198-
in1.copy(), self.config.irreps_in1
199-
),
200-
L2_in=maybe_transpose_input_for_test_impl(
201-
in2.copy(), self.config.irreps_in2
202-
),
203-
weights=weights.copy(),
201+
L1_in=run_in1,
202+
L2_in=run_in2,
203+
weights=run_weights,
204204
L3_out=test_out,
205205
graph=graph,
206206
)
207207

208-
if self.config.layout == "ir_mul":
209-
test_out = IrrepLayoutUtils.transpose_irrep_layout(
210-
test_out, self.config.irreps_out, "ir_mul", "mul_ir"
211-
)
208+
test_out = IrrepLayoutUtils.transpose_irrep_layout(
209+
test_out, self.config.irreps_out, self.config.layout, "mul_ir"
210+
)
212211

213212
for name, to_check, ground_truth in [("output", ref_out, test_out)]:
214213
result[name] = check_similiarity(name, to_check, ground_truth, thresh)
@@ -221,22 +220,29 @@ def maybe_transpose_input_for_test_impl(x, irreps):
221220

222221
for i in range(num_trials):
223222
repeated_run = out.copy()
223+
rep_in1, rep_in2, rep_weights = [
224+
buf.copy() for buf in (in1, in2, weights)
225+
]
226+
rep_in1, rep_in2 = [
227+
IrrepLayoutUtils.transpose_irrep_layout(
228+
arr, irreps, "mul_ir", self.config.layout
229+
)
230+
for arr, irreps in zip(
231+
(rep_in1, rep_in2),
232+
(self.config.irreps_in1, self.config.irreps_in2),
233+
)
234+
]
224235
self.forward_cpu(
225-
L1_in=maybe_transpose_input_for_test_impl(
226-
in1.copy(), self.config.irreps_in1
227-
),
228-
L2_in=maybe_transpose_input_for_test_impl(
229-
in2.copy(), self.config.irreps_in2
230-
),
231-
weights=weights.copy(),
236+
L1_in=rep_in1,
237+
L2_in=rep_in2,
238+
weights=rep_weights,
232239
L3_out=repeated_run,
233240
graph=graph,
234241
)
235242

236-
if self.config.layout == "ir_mul":
237-
repeated_run = IrrepLayoutUtils.transpose_irrep_layout(
238-
repeated_run, self.config.irreps_out, "ir_mul", "mul_ir"
239-
)
243+
repeated_run = IrrepLayoutUtils.transpose_irrep_layout(
244+
repeated_run, self.config.irreps_out, self.config.layout, "mul_ir"
245+
)
240246

241247
for name, to_check, ground_truth in [
242248
("output", repeated_run, test_out)
@@ -413,13 +419,6 @@ def test_correctness_backward(
413419
reference_implementation=None,
414420
high_precision_ref=False,
415421
):
416-
def maybe_transpose_input_for_test_impl(x, irreps):
417-
if self.config.layout == "ir_mul":
418-
return IrrepLayoutUtils.transpose_irrep_layout(
419-
x, irreps, "mul_ir", "ir_mul"
420-
)
421-
return x
422-
423422
if reference_implementation is None:
424423
from openequivariance._torch.E3NNConv import E3NNConv
425424

@@ -469,34 +468,39 @@ def maybe_transpose_input_for_test_impl(x, irreps):
469468
test_in1_grad = in1_grad.copy()
470469
test_in2_grad = in2_grad.copy()
471470

472-
test_L3_grad = out_grad.copy()
473-
if self.config.layout == "ir_mul":
474-
test_L3_grad = IrrepLayoutUtils.transpose_irrep_layout(
475-
test_L3_grad, self.config.irreps_out, "mul_ir", "ir_mul"
471+
test_in1, test_in2, test_L3_grad = [
472+
buf.copy() for buf in (in1, in2, out_grad)
473+
]
474+
test_in1, test_in2, test_L3_grad = [
475+
IrrepLayoutUtils.transpose_irrep_layout(
476+
arr, irreps, "mul_ir", self.config.layout
477+
)
478+
for arr, irreps in zip(
479+
(test_in1, test_in2, test_L3_grad),
480+
(self.config.irreps_in1, self.config.irreps_in2, self.config.irreps_out),
476481
)
482+
]
477483

478484
self.backward_cpu(
479-
L1_in=maybe_transpose_input_for_test_impl(
480-
in1.copy(), self.config.irreps_in1
481-
),
485+
L1_in=test_in1,
482486
L1_grad=test_in1_grad,
483-
L2_in=maybe_transpose_input_for_test_impl(
484-
in2.copy(), self.config.irreps_in2
485-
),
487+
L2_in=test_in2,
486488
L2_grad=test_in2_grad,
487489
L3_grad=test_L3_grad,
488490
weights=weights.copy(),
489491
weights_grad=test_weights_grad,
490492
graph=graph,
491493
)
492494

493-
if self.config.layout == "ir_mul":
494-
test_in1_grad = IrrepLayoutUtils.transpose_irrep_layout(
495-
test_in1_grad, self.config.irreps_in1, "ir_mul", "mul_ir"
495+
test_in1_grad, test_in2_grad = [
496+
IrrepLayoutUtils.transpose_irrep_layout(
497+
arr, irreps, self.config.layout, "mul_ir"
496498
)
497-
test_in2_grad = IrrepLayoutUtils.transpose_irrep_layout(
498-
test_in2_grad, self.config.irreps_in2, "ir_mul", "mul_ir"
499+
for arr, irreps in zip(
500+
(test_in1_grad, test_in2_grad),
501+
(self.config.irreps_in1, self.config.irreps_in2),
499502
)
503+
]
500504

501505
for name, to_check, ground_truth, threshold in [
502506
("weight_grad", test_weights_grad, ref_weights_grad, thresh),
@@ -515,13 +519,6 @@ def test_correctness_double_backward(
515519
reference_implementation=None,
516520
high_precision_ref=False,
517521
):
518-
def maybe_transpose_input_for_test_impl(tp, x, irreps):
519-
if tp is self and tp.config.layout == "ir_mul":
520-
return IrrepLayoutUtils.transpose_irrep_layout(
521-
x, irreps, "mul_ir", "ir_mul"
522-
)
523-
return x
524-
525522
buffers = get_random_buffers_double_backward_conv(
526523
self.config, graph.node_count, graph.nnz, prng_seed
527524
)
@@ -542,6 +539,7 @@ def maybe_transpose_input_for_test_impl(tp, x, irreps):
542539
result = {"thresh": thresh}
543540
tensors = []
544541
for i, tp in enumerate([self, reference_tp]):
542+
is_test_impl = i == 0
545543
buffers_copy = [buf.copy() for buf in buffers]
546544

547545
if i == 1 and high_precision_ref:
@@ -558,21 +556,25 @@ def maybe_transpose_input_for_test_impl(tp, x, irreps):
558556
weights_dgrad, not tp.config.shared_weights
559557
)
560558

561-
db_in1 = maybe_transpose_input_for_test_impl(tp, in1, tp.config.irreps_in1)
562-
db_in2 = maybe_transpose_input_for_test_impl(tp, in2, tp.config.irreps_in2)
563-
db_out_grad = out_grad
564-
db_in1_dgrad = in1_dgrad
565-
db_in2_dgrad = in2_dgrad
566-
if tp is self and tp.config.layout == "ir_mul":
567-
db_out_grad = IrrepLayoutUtils.transpose_irrep_layout(
568-
out_grad, tp.config.irreps_out, "mul_ir", "ir_mul"
569-
)
570-
db_in1_dgrad = IrrepLayoutUtils.transpose_irrep_layout(
571-
in1_dgrad, tp.config.irreps_in1, "mul_ir", "ir_mul"
572-
)
573-
db_in2_dgrad = IrrepLayoutUtils.transpose_irrep_layout(
574-
in2_dgrad, tp.config.irreps_in2, "mul_ir", "ir_mul"
575-
)
559+
db_in1, db_in2, db_out_grad, db_in1_dgrad, db_in2_dgrad = [
560+
buf.copy() for buf in (in1, in2, out_grad, in1_dgrad, in2_dgrad)
561+
]
562+
if is_test_impl:
563+
db_in1, db_in2, db_out_grad, db_in1_dgrad, db_in2_dgrad = [
564+
IrrepLayoutUtils.transpose_irrep_layout(
565+
arr, irreps, "mul_ir", tp.config.layout
566+
)
567+
for arr, irreps in zip(
568+
(db_in1, db_in2, db_out_grad, db_in1_dgrad, db_in2_dgrad),
569+
(
570+
tp.config.irreps_in1,
571+
tp.config.irreps_in2,
572+
tp.config.irreps_out,
573+
tp.config.irreps_in1,
574+
tp.config.irreps_in2,
575+
),
576+
)
577+
]
576578

577579
in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu(
578580
db_in1,
@@ -585,16 +587,16 @@ def maybe_transpose_input_for_test_impl(tp, x, irreps):
585587
graph,
586588
)
587589

588-
if tp is self and tp.config.layout == "ir_mul":
589-
out_dgrad = IrrepLayoutUtils.transpose_irrep_layout(
590-
out_dgrad, tp.config.irreps_out, "ir_mul", "mul_ir"
591-
)
592-
in1_grad = IrrepLayoutUtils.transpose_irrep_layout(
593-
in1_grad, tp.config.irreps_in1, "ir_mul", "mul_ir"
594-
)
595-
in2_grad = IrrepLayoutUtils.transpose_irrep_layout(
596-
in2_grad, tp.config.irreps_in2, "ir_mul", "mul_ir"
597-
)
590+
if is_test_impl:
591+
out_dgrad, in1_grad, in2_grad = [
592+
IrrepLayoutUtils.transpose_irrep_layout(
593+
arr, irreps, tp.config.layout, "mul_ir"
594+
)
595+
for arr, irreps in zip(
596+
(out_dgrad, in1_grad, in2_grad),
597+
(tp.config.irreps_out, tp.config.irreps_in1, tp.config.irreps_in2),
598+
)
599+
]
598600

599601
tensors.append(
600602
(

tests/conv_test.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -284,8 +284,9 @@ def conv_object(self, request, problem, extra_conv_constructor_args):
284284
return module.to(switch_map[problem.irrep_dtype])
285285

286286

287-
def ir_mul_representative_uvu_problems():
288-
return [
287+
class TestIrMulLayout(ConvCorrectness):
288+
production_model_tpps = mace_problems() + \
289+
[
289290
oeq.TPProblem(
290291
"5x5e",
291292
"1x3e",
@@ -306,10 +307,6 @@ def ir_mul_representative_uvu_problems():
306307
),
307308
]
308309

309-
310-
class TestIrMulLayout(ConvCorrectness):
311-
production_model_tpps = mace_problems() + ir_mul_representative_uvu_problems()
312-
313310
@pytest.fixture(params=production_model_tpps, ids=lambda x: x.label, scope="class")
314311
def problem(self, request, dtype):
315312
problem = request.param.clone()

0 commit comments

Comments
 (0)