Skip to content

Commit 95783ef

Browse files
committed
More test cleaning.
1 parent 3c9ed29 commit 95783ef

2 files changed

Lines changed: 21 additions & 25 deletions

File tree

tests/batch_test.py

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -272,9 +272,9 @@ def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax):
272272
tp.to(switch_map[problem.irrep_dtype])
273273
return tp, tp.config
274274

275-
276-
def ir_mul_representative_uvu_problems():
277-
return [
275+
class TestIrMulLayoutMACE(TPCorrectness):
276+
production_model_tpps = mace_problems() + \
277+
[
278278
oeq.TPProblem(
279279
"5x5e",
280280
"1x3e",
@@ -293,11 +293,7 @@ def ir_mul_representative_uvu_problems():
293293
internal_weights=False,
294294
label="ir_mul_repr_13x1x13_l535",
295295
),
296-
]
297-
298-
299-
class TestIrMulLayoutMACE(TPCorrectness):
300-
production_model_tpps = mace_problems() + ir_mul_representative_uvu_problems()
296+
]
301297

302298
@pytest.fixture(params=production_model_tpps, ids=lambda x: x.label, scope="class")
303299
def problem(self, request, dtype):
@@ -307,23 +303,6 @@ def problem(self, request, dtype):
307303
return problem
308304

309305

310-
def test_ir_mul_rejects_uvw_problem(dtype):
311-
problem = oeq.TPProblem(
312-
"5x5e",
313-
"1x3e",
314-
"5x5e",
315-
[(0, 0, 0, "uvw", True)],
316-
shared_weights=False,
317-
internal_weights=False,
318-
irrep_dtype=dtype,
319-
weight_dtype=dtype,
320-
layout="ir_mul",
321-
)
322-
323-
with pytest.raises(AssertionError, match="layout='ir_mul'"):
324-
oeq.TensorProduct(problem)
325-
326-
327306
class TestTorchToSubmodule:
328307
"""Test that TensorProduct works correctly as a submodule when parent's .to() is called"""
329308

tests/input_validation_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,20 @@ def test_cpp_checks_forward_dtype(executable_and_buffers, subtests):
138138
with pytest.raises(RuntimeError, match=r"Dtype mismatch"):
139139
buffers[i] = buffers[i].to(dtype=torch.bfloat16)
140140
executable(*buffers)
141+
142+
143+
def test_ir_mul_rejects_uvw_problem(dtype):
144+
problem = TPProblem(
145+
"5x5e",
146+
"1x3e",
147+
"5x5e",
148+
[(0, 0, 0, "uvw", True)],
149+
shared_weights=False,
150+
internal_weights=False,
151+
irrep_dtype=dtype,
152+
weight_dtype=dtype,
153+
layout="ir_mul",
154+
)
155+
156+
with pytest.raises(AssertionError, match="layout='ir_mul'"):
157+
TensorProduct(problem)

0 commit comments

Comments
 (0)