
Commit 379fd28

Convolution test is failing.

1 parent: 85e988f

5 files changed: 130 additions & 33 deletions


openequivariance/openequivariance/core/ComputationSchedule.py

Lines changed: 4 additions & 5 deletions
@@ -29,6 +29,10 @@ def __init__(self, src_irreps, src_views, idxs):
         src_ranges = [src_irreps.slices()[idx] for idx in self.src_dst_map]
         dst_ranges = [self.dst_irreps.slices()[i] for i in self.src_dst_map.values()]

+        self.storeback_procedure = {idx: "write" for idx in self.idxs}
+        self.persist_load = False
+        self.persist_store = False
+
         if src_views[0].layout == "ir_mul":
             return

@@ -55,11 +59,6 @@ def __init__(self, src_irreps, src_views, idxs):
             self.dst_ranges.append(slice(dst_start, dst_end))
         self.copy_ranges = list(zip(self.src_ranges, self.dst_ranges))

-        self.persist_load = False
-        self.persist_store = False
-
-        self.storeback_procedure = {idx: "write" for idx in self.idxs}
-

 class CGTensor:
     def __init__(self, l1, l2, l3, normalization_factor, dtype):
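
Moving these three assignments above the early return matters: `__init__` bails out for the "ir_mul" layout, so anything assigned below the `return` never exists on instances that take that branch. A minimal sketch of the hazard (hypothetical class, not the real schedule code):

```python
# Hypothetical illustration of the bug class this hunk addresses: attributes
# assigned after an early `return` in __init__ are skipped entirely on the
# early-exit branch, so later reads raise AttributeError.
class TransposeMap:
    def __init__(self, layout):
        if layout == "ir_mul":
            return  # early exit: nothing below runs for this layout
        self.storeback_procedure = {}

try:
    TransposeMap("ir_mul").storeback_procedure
except AttributeError as err:
    print(err)  # 'TransposeMap' object has no attribute 'storeback_procedure'
```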

openequivariance/openequivariance/core/ConvolutionBase.py

Lines changed: 102 additions & 17 deletions
@@ -1,15 +1,16 @@
 import copy
+
 import numpy as np
+
+from openequivariance.benchmark.correctness_utils import check_similiarity
+from openequivariance.benchmark.logging_utils import bcolors, getLogger
 from openequivariance.benchmark.random_buffer_utils import (
-    get_random_buffers_forward_conv,
     get_random_buffers_backward_conv,
     get_random_buffers_double_backward_conv,
+    get_random_buffers_forward_conv,
 )
-
-from openequivariance.benchmark.logging_utils import getLogger, bcolors
-from openequivariance.benchmark.correctness_utils import check_similiarity
 from openequivariance.core.e3nn_lite import wigner_3j
-from openequivariance.core.utils import benchmark
+from openequivariance.core.utils import IrrepLayoutUtils, benchmark

 logger = getLogger()

@@ -143,6 +144,13 @@ def test_correctness_forward(
         check_reproducible=True,
         high_precision_ref=False,
     ):
+        def maybe_transpose_input_for_test_impl(x, irreps):
+            if self.config.layout == "ir_mul":
+                return IrrepLayoutUtils.transpose_irrep_layout(
+                    x, irreps, "mul_ir", "ir_mul"
+                )
+            return x
+
         if reference_implementation is None:
             from openequivariance._torch.E3NNConv import E3NNConv

@@ -186,13 +194,22 @@ def test_correctness_forward(

         test_out = out.copy()
         self.forward_cpu(
-            L1_in=in1.copy(),
-            L2_in=in2.copy(),
+            L1_in=maybe_transpose_input_for_test_impl(
+                in1.copy(), self.config.irreps_in1
+            ),
+            L2_in=maybe_transpose_input_for_test_impl(
+                in2.copy(), self.config.irreps_in2
+            ),
             weights=weights.copy(),
             L3_out=test_out,
             graph=graph,
         )

+        if self.config.layout == "ir_mul":
+            test_out = IrrepLayoutUtils.transpose_irrep_layout(
+                test_out, self.config.irreps_out, "ir_mul", "mul_ir"
+            )
+
         for name, to_check, ground_truth in [("output", ref_out, test_out)]:
             result[name] = check_similiarity(name, to_check, ground_truth, thresh)

@@ -205,13 +222,22 @@ def test_correctness_forward(
             for i in range(num_trials):
                 repeated_run = out.copy()
                 self.forward_cpu(
-                    L1_in=in1.copy(),
-                    L2_in=in2.copy(),
+                    L1_in=maybe_transpose_input_for_test_impl(
+                        in1.copy(), self.config.irreps_in1
+                    ),
+                    L2_in=maybe_transpose_input_for_test_impl(
+                        in2.copy(), self.config.irreps_in2
+                    ),
                     weights=weights.copy(),
                     L3_out=repeated_run,
                     graph=graph,
                 )

+                if self.config.layout == "ir_mul":
+                    repeated_run = IrrepLayoutUtils.transpose_irrep_layout(
+                        repeated_run, self.config.irreps_out, "ir_mul", "mul_ir"
+                    )
+
                 for name, to_check, ground_truth in [
                     ("output", repeated_run, test_out)
                 ]:
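
Some context on the round trip above: `mul_ir` stores each irrep block multiplicity-major and `ir_mul` stores it component-major, so the test transposes inputs into the kernel's layout and transposes the kernel's output back to `mul_ir` before comparing against the reference. A minimal numpy sketch of the per-block transposition, assuming a simplified `[(mul, ir_dim), ...]` irreps description in place of the real `Irreps` object that `IrrepLayoutUtils.transpose_irrep_layout` consumes:

```python
import numpy as np

def transpose_layout(x, irreps, src, dst):
    """Per-irrep-block transpose between mul_ir and ir_mul layouts.

    Illustrative stand-in for IrrepLayoutUtils.transpose_irrep_layout;
    irreps is a list of (multiplicity, irrep_dimension) pairs.
    """
    assert {src, dst} == {"mul_ir", "ir_mul"}
    out, offset = np.empty_like(x), 0
    for mul, ir_dim in irreps:
        n = mul * ir_dim
        # mul_ir flattens a (mul, ir_dim) block row-major; ir_mul flattens
        # the (ir_dim, mul) transpose. Swapping the last two axes converts.
        shape = (mul, ir_dim) if src == "mul_ir" else (ir_dim, mul)
        block = x[..., offset : offset + n].reshape(*x.shape[:-1], *shape)
        out[..., offset : offset + n] = np.swapaxes(block, -1, -2).reshape(
            *x.shape[:-1], n
        )
        offset += n
    return out
```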
@@ -387,6 +413,13 @@ def test_correctness_backward(
         reference_implementation=None,
         high_precision_ref=False,
     ):
+        def maybe_transpose_input_for_test_impl(x, irreps):
+            if self.config.layout == "ir_mul":
+                return IrrepLayoutUtils.transpose_irrep_layout(
+                    x, irreps, "mul_ir", "ir_mul"
+                )
+            return x
+
         if reference_implementation is None:
             from openequivariance._torch.E3NNConv import E3NNConv

@@ -436,17 +469,35 @@
         test_in1_grad = in1_grad.copy()
         test_in2_grad = in2_grad.copy()

+        test_L3_grad = out_grad.copy()
+        if self.config.layout == "ir_mul":
+            test_L3_grad = IrrepLayoutUtils.transpose_irrep_layout(
+                test_L3_grad, self.config.irreps_out, "mul_ir", "ir_mul"
+            )
+
         self.backward_cpu(
-            L1_in=in1.copy(),
+            L1_in=maybe_transpose_input_for_test_impl(
+                in1.copy(), self.config.irreps_in1
+            ),
             L1_grad=test_in1_grad,
-            L2_in=in2.copy(),
+            L2_in=maybe_transpose_input_for_test_impl(
+                in2.copy(), self.config.irreps_in2
+            ),
             L2_grad=test_in2_grad,
-            L3_grad=out_grad.copy(),
+            L3_grad=test_L3_grad,
             weights=weights.copy(),
             weights_grad=test_weights_grad,
             graph=graph,
         )

+        if self.config.layout == "ir_mul":
+            test_in1_grad = IrrepLayoutUtils.transpose_irrep_layout(
+                test_in1_grad, self.config.irreps_in1, "ir_mul", "mul_ir"
+            )
+            test_in2_grad = IrrepLayoutUtils.transpose_irrep_layout(
+                test_in2_grad, self.config.irreps_in2, "ir_mul", "mul_ir"
+            )
+
         for name, to_check, ground_truth, threshold in [
             ("weight_grad", test_weights_grad, ref_weights_grad, thresh),
             ("in1_grad", test_in1_grad, ref_in1_grad, thresh),
@@ -464,6 +515,13 @@ def test_correctness_double_backward(
         reference_implementation=None,
         high_precision_ref=False,
     ):
+        def maybe_transpose_input_for_test_impl(tp, x, irreps):
+            if tp is self and tp.config.layout == "ir_mul":
+                return IrrepLayoutUtils.transpose_irrep_layout(
+                    x, irreps, "mul_ir", "ir_mul"
+                )
+            return x
+
         buffers = get_random_buffers_double_backward_conv(
             self.config, graph.node_count, graph.nnz, prng_seed
         )
@@ -500,17 +558,44 @@
                 weights_dgrad, not tp.config.shared_weights
             )

+            db_in1 = maybe_transpose_input_for_test_impl(tp, in1, tp.config.irreps_in1)
+            db_in2 = maybe_transpose_input_for_test_impl(tp, in2, tp.config.irreps_in2)
+            db_out_grad = out_grad
+            db_in1_dgrad = in1_dgrad
+            db_in2_dgrad = in2_dgrad
+            if tp is self and tp.config.layout == "ir_mul":
+                db_out_grad = IrrepLayoutUtils.transpose_irrep_layout(
+                    out_grad, tp.config.irreps_out, "mul_ir", "ir_mul"
+                )
+                db_in1_dgrad = IrrepLayoutUtils.transpose_irrep_layout(
+                    in1_dgrad, tp.config.irreps_in1, "mul_ir", "ir_mul"
+                )
+                db_in2_dgrad = IrrepLayoutUtils.transpose_irrep_layout(
+                    in2_dgrad, tp.config.irreps_in2, "mul_ir", "ir_mul"
+                )
+
             in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu(
-                in1,
-                in2,
-                out_grad,
+                db_in1,
+                db_in2,
+                db_out_grad,
                 weights_reordered,
                 weights_dgrad_reordered,
-                in1_dgrad,
-                in2_dgrad,
+                db_in1_dgrad,
+                db_in2_dgrad,
                 graph,
             )

+            if tp is self and tp.config.layout == "ir_mul":
+                out_dgrad = IrrepLayoutUtils.transpose_irrep_layout(
+                    out_dgrad, tp.config.irreps_out, "ir_mul", "mul_ir"
+                )
+                in1_grad = IrrepLayoutUtils.transpose_irrep_layout(
+                    in1_grad, tp.config.irreps_in1, "ir_mul", "mul_ir"
+                )
+                in2_grad = IrrepLayoutUtils.transpose_irrep_layout(
+                    in2_grad, tp.config.irreps_in2, "ir_mul", "mul_ir"
+                )
+
             tensors.append(
                 (
                     out_dgrad,
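
Note the `tp is self` guard on every transposition block in this hunk: `tp` is whichever implementation runs `double_backward_cpu`, the one under test or the reference, and only the implementation under test can carry the "ir_mul" layout, so the reference path must keep its buffers in `mul_ir`.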

openequivariance/openequivariance/core/LoopUnrollConv.py

Lines changed: 8 additions & 6 deletions
@@ -1,18 +1,18 @@
-import numpy as np
 import json

-from openequivariance.core.ConvolutionBase import ConvolutionBase
+import numpy as np
+
 from openequivariance.core.ComputationSchedule import (
     ComputationSchedule,
     SMEMCapacityException,
 )
-
-from openequivariance.templates.jinja_utils import get_jinja_environment
+from openequivariance.core.ConvolutionBase import ConvolutionBase
 from openequivariance.core.utils import (
-    filter_and_analyze_problem,
     dtype_to_enum,
+    filter_and_analyze_problem,
     hash_str_64,
 )
+from openequivariance.templates.jinja_utils import get_jinja_environment


 class LoopUnrollConv(ConvolutionBase):
@@ -114,9 +114,11 @@ def generate_double_backward_schedule(warps_per_block):
             except SMEMCapacityException:
                 warp_count -= 1
                 if warp_count == 0:
-                    raise SMEMCapacityException(
+                    raise RuntimeError(
                         "Tensor product schedule generation failed, shared memory inadequate!"
                     )
+            except Exception:
+                raise

         if not deterministic:
             for segment in self.forward_schedule.segments:
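
For context, `generate_double_backward_schedule` shrinks the warp count and retries whenever the schedule overflows shared memory; this hunk changes the terminal failure from re-raising `SMEMCapacityException` to a `RuntimeError`. A minimal sketch of that backoff pattern, with a stand-in exception class and a hypothetical `build` callback:

```python
class SMEMCapacityException(Exception):
    """Stand-in for the real exception in openequivariance.core.ComputationSchedule."""

def build_schedule_with_backoff(build, max_warps):
    # Shrink the warp count until the schedule fits in shared memory;
    # surface a RuntimeError once no configuration fits, as the commit now does.
    warp_count = max_warps
    while warp_count > 0:
        try:
            return build(warp_count)
        except SMEMCapacityException:
            warp_count -= 1
    raise RuntimeError(
        "Tensor product schedule generation failed, shared memory inadequate!"
    )
```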

openequivariance/openequivariance/templates/macros.jinja

Lines changed: 5 additions & 5 deletions
@@ -91,7 +91,7 @@ Keys map to lists of tuples with (name, dtype, num_elements) of each subarray.
 {%- set dim = src_mul_ir.ir.dim %}
 {%- set mul = src_mul_ir.mul %}
 {%- for i in range(dim) %}
-ROW_OPERATION({{mul}}, {{loop_var}}, {{smem_ptr}}[{{dst_rng.start + loop_var + i * mul}} + lane_id] = {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}];)
+ROW_OPERATION({{mul}}, {{loop_var}}, {{smem_ptr}}[{{dst_rng.start + i * mul}} + {{loop_var}} + lane_id] = {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}];)
 {%- endfor %}
 {%- endfor %}
 {%- endif %}
@@ -113,7 +113,7 @@ Keys map to lists of tuples with (name, dtype, num_elements) of each subarray.
 {%- set dim = src_mul_ir.ir.dim %}
 {%- set mul = src_mul_ir.mul %}
 {%- for i in range(dim) %}
-ROW_OPERATION({{mul}}, {{loop_var}}, {{smem_ptr}}[{{dst_rng.start + loop_var + i * mul}} + lane_id] = {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}];)
+ROW_OPERATION({{mul}}, {{loop_var}}, {{smem_ptr}}[{{dst_rng.start + i * mul}} + {{loop_var}} + lane_id] = {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}];)
 {%- endfor %}
 {%- endfor %}
 {%- endif %}
@@ -144,15 +144,15 @@ Keys map to lists of tuples with (name, dtype, num_elements) of each subarray.
 {%- set mul = src_mul_ir.mul %}
 {%- if map.storeback_procedure[idx] == "write" %}
 {%- for i in range(dim) %}
-ROW_OPERATION({{mul}}, {{loop_var}}, {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}] = {{smem_ptr}}[{{dst_rng.start + loop_var + i * mul}} + lane_id];)
+ROW_OPERATION({{mul}}, {{loop_var}}, {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}] = {{smem_ptr}}[{{dst_rng.start + i * mul}} + {{loop_var}} + lane_id];)
 {%- endfor %}
 {%- elif map.storeback_procedure[idx] == "accumulate" %}
 {%- for i in range(dim) %}
-ROW_OPERATION({{mul}}, {{loop_var}}, {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}] += {{smem_ptr}}[{{dst_rng.start + loop_var + i * mul}} + lane_id];)
+ROW_OPERATION({{mul}}, {{loop_var}}, {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}] += {{smem_ptr}}[{{dst_rng.start + i * mul}} + {{loop_var}} + lane_id];)
 {%- endfor %}
 {%- elif map.storeback_procedure[idx] == "atomic_accumulate" %}
 {%- for i in range(dim) %}
-ROW_OPERATION({{mul}}, {{loop_var}}, atomicAdd({{glb_ptr_shft}} + {{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}, {{smem_ptr}}[{{dst_rng.start + loop_var + i * mul}} + lane_id]);)
+ROW_OPERATION({{mul}}, {{loop_var}}, atomicAdd({{glb_ptr_shft}} + {{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}, {{smem_ptr}}[{{dst_rng.start + i * mul}} + {{loop_var}} + lane_id]);)
 {%- endfor %}
 {%- endif %}
 {%- endfor %}
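
The indexing fix in all five ROW_OPERATION lines is the same: judging from how the corrected lines emit it, `loop_var` holds the name of the generated kernel's loop variable, i.e. a string, so it cannot take part in template-time integer arithmetic inside a single `{{ ... }}` expression. The fix keeps only render-time constants (`dst_rng.start + i * mul`) inside the Jinja expression and appends `+ {{loop_var}}` as runtime index code. A small demonstration with hypothetical names and values:

```python
from jinja2 import Template

ctx = {"start": 8, "i": 1, "mul": 4, "loop_var": "j"}

# Corrected form: constants fold at render time; the loop variable is added
# in the generated code, at kernel runtime.
fixed = Template("smem[{{start + i * mul}} + {{loop_var}} + lane_id]")
print(fixed.render(**ctx))  # smem[12 + j + lane_id]

# Old form: mixing the string "j" into template-time integer arithmetic
# fails when the expression is evaluated.
broken = Template("smem[{{start + loop_var + i * mul}} + lane_id]")
try:
    broken.render(**ctx)
except TypeError as err:
    print(f"render failed: {err}")
```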

tests/conv_test.py

Lines changed: 11 additions & 0 deletions
@@ -284,6 +284,17 @@ def conv_object(self, request, problem, extra_conv_constructor_args):
         return module.to(switch_map[problem.irrep_dtype])


+class TestIrMulLayout(ConvCorrectness):
+    production_model_tpps = mace_problems()
+
+    @pytest.fixture(params=production_model_tpps, ids=lambda x: x.label, scope="class")
+    def problem(self, request, dtype):
+        problem = request.param.clone()
+        problem.irrep_dtype, problem.weight_dtype = dtype, dtype
+        problem.layout = "ir_mul"
+        return problem
+
+
 class TestTorchToSubmodule:
     """Test that TensorProductConv works as a submodule when parent's .to() is called"""
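
`TestIrMulLayout` inherits every check from `ConvCorrectness` and overrides only the `problem` fixture, so the full correctness suite reruns against the MACE problems with `layout = "ir_mul"`. A minimal sketch of this pytest fixture-override pattern (hypothetical classes, not the real suite):

```python
import pytest

class BaseSuite:
    # Not collected directly by pytest (name lacks the Test prefix); it only
    # supplies tests and a default fixture for subclasses.
    @pytest.fixture(params=["mul_ir"])
    def layout(self, request):
        return request.param

    def test_layout_is_supported(self, layout):
        assert layout in ("mul_ir", "ir_mul")

class TestIrMul(BaseSuite):
    # Overriding the fixture reparametrizes every test inherited from BaseSuite.
    @pytest.fixture(params=["ir_mul"])
    def layout(self, request):
        return request.param
```

Assuming standard pytest discovery, something like `pytest tests/conv_test.py -k TestIrMulLayout` should run only the new suite.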
