Skip to content

Commit 4806748

Browse files
committed
Fixed more stuff.
1 parent 7dcd887 commit 4806748

3 files changed

Lines changed: 70 additions & 2 deletions

File tree

openequivariance/openequivariance/core/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,11 @@ def filter_and_analyze_problem(problem):
9696
f"Connection mode must be 'uvu' or 'uvw', got {problem.instructions[0].connection_mode}"
9797
)
9898

99+
if problem.layout == "ir_mul":
100+
assert problem.instructions[0].connection_mode == "uvu", (
101+
"layout='ir_mul' is only supported for pure 'uvu' problems"
102+
)
103+
99104
assert problem.irrep_dtype == problem.weight_dtype, (
100105
f"irrep_dtype and weight_dtype must be the same, got {problem.irrep_dtype} and {problem.weight_dtype}"
101106
)

tests/batch_test.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,8 +273,31 @@ def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax):
273273
return tp, tp.config
274274

275275

276+
def ir_mul_representative_uvu_problems():
277+
return [
278+
oeq.TPProblem(
279+
"5x5e",
280+
"1x3e",
281+
"5x5e",
282+
[(0, 0, 0, "uvu", True)],
283+
shared_weights=False,
284+
internal_weights=False,
285+
label="ir_mul_repr_5x1x5_l535",
286+
),
287+
oeq.TPProblem(
288+
"13x5e",
289+
"1x3e",
290+
"13x5e",
291+
[(0, 0, 0, "uvu", True)],
292+
shared_weights=False,
293+
internal_weights=False,
294+
label="ir_mul_repr_13x1x13_l535",
295+
),
296+
]
297+
298+
276299
class TestIrMulLayoutMACE(TPCorrectness):
277-
production_model_tpps = mace_problems()
300+
production_model_tpps = mace_problems() + ir_mul_representative_uvu_problems()
278301

279302
@pytest.fixture(params=production_model_tpps, ids=lambda x: x.label, scope="class")
280303
def problem(self, request, dtype):
@@ -284,6 +307,23 @@ def problem(self, request, dtype):
284307
return problem
285308

286309

310+
def test_ir_mul_rejects_uvw_problem(dtype):
311+
problem = oeq.TPProblem(
312+
"5x5e",
313+
"1x3e",
314+
"5x5e",
315+
[(0, 0, 0, "uvw", True)],
316+
shared_weights=False,
317+
internal_weights=False,
318+
irrep_dtype=dtype,
319+
weight_dtype=dtype,
320+
layout="ir_mul",
321+
)
322+
323+
with pytest.raises(AssertionError, match="layout='ir_mul'"):
324+
oeq.TensorProduct(problem)
325+
326+
287327
class TestTorchToSubmodule:
288328
"""Test that TensorProduct works correctly as a submodule when parent's .to() is called"""
289329

tests/conv_test.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,8 +284,31 @@ def conv_object(self, request, problem, extra_conv_constructor_args):
284284
return module.to(switch_map[problem.irrep_dtype])
285285

286286

287+
def ir_mul_representative_uvu_problems():
288+
return [
289+
oeq.TPProblem(
290+
"5x5e",
291+
"1x3e",
292+
"5x5e",
293+
[(0, 0, 0, "uvu", True)],
294+
shared_weights=False,
295+
internal_weights=False,
296+
label="ir_mul_repr_5x1x5_l535",
297+
),
298+
oeq.TPProblem(
299+
"13x5e",
300+
"1x3e",
301+
"13x5e",
302+
[(0, 0, 0, "uvu", True)],
303+
shared_weights=False,
304+
internal_weights=False,
305+
label="ir_mul_repr_13x1x13_l535",
306+
),
307+
]
308+
309+
287310
class TestIrMulLayout(ConvCorrectness):
288-
production_model_tpps = mace_problems()
311+
production_model_tpps = mace_problems() + ir_mul_representative_uvu_problems()
289312

290313
@pytest.fixture(params=production_model_tpps, ids=lambda x: x.label, scope="class")
291314
def problem(self, request, dtype):

0 commit comments

Comments
 (0)