Skip to content

Commit 958be46

Browse files
Clean up integration tests: deduplicate device params, remove redundant test
- Extract repeated device parametrize block into DEVICES constant - Remove test_module_submodules_on_device (redundant with test_module_parameters_on_device) - Remove dead float16 code path in test_module_dtype - Collapse Green/default branches in test_integration and test_integration_on_device - Move misplaced batch shapes comment to correct test
1 parent f4abd1b commit 958be46

1 file changed

Lines changed: 75 additions & 60 deletions

File tree

tests/test_integration.py

Lines changed: 75 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -30,18 +30,24 @@
3030
),
3131
}
3232

33+
DEVICES = [
34+
"cpu",
35+
pytest.param(
36+
"cuda",
37+
marks=pytest.mark.skipif(
38+
not torch.cuda.is_available(), reason="CUDA not available"
39+
),
40+
),
41+
]
42+
3343

3444
@pytest.mark.parametrize("model_name", model_list)
3545
def test_integration(model_name):
3646
model_class = getattr(spd_learn.models, model_name)
3747

38-
params = {}
48+
params = {"sfreq": 125} if model_name == "Green" else {}
3949
if model_name == "TensorCSPNet":
40-
# TensorCSPNet requires a different input shape
4150
x = torch.randn(2, 9, 22, 1000)
42-
elif model_name == "Green":
43-
params = {"sfreq": 125}
44-
x = torch.randn(2, 22, 1000)
4551
else:
4652
x = torch.randn(2, 22, 1000)
4753

@@ -89,20 +95,7 @@ def test_module_expose_device_dtype(module_name):
8995
assert layer is not None
9096

9197

92-
# Test that all parameters of the module are on the expected device.
93-
@pytest.mark.parametrize(
94-
"device",
95-
[
96-
"cpu",
97-
pytest.param(
98-
"cuda",
99-
marks=pytest.mark.skipif(
100-
not torch.cuda.is_available(), reason="CUDA not available"
101-
),
102-
),
103-
# pytest.param("mps", marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available (MAC only)"))
104-
],
105-
)
98+
@pytest.mark.parametrize("device", DEVICES)
10699
@pytest.mark.parametrize("module_name", module_list)
107100
def test_module_parameters_on_device(module_name, device):
108101
"""Instantiate the module on the given device and verify that each parameter is located on that device."""
@@ -119,29 +112,7 @@ def test_module_parameters_on_device(module_name, device):
119112
)
120113

121114

122-
# Optionally, test that all submodules’ parameters are on the expected device.
123-
@pytest.mark.parametrize(
124-
"device", ["cpu"]
125-
) # if you want to test submodules only on CPU in CI, or parameterize as above
126-
@pytest.mark.parametrize("module_name", module_list)
127-
def test_module_submodules_on_device(module_name, device):
128-
"""Verify that for each submodule in the module, its parameters are on the correct device."""
129-
module_class = getattr(spd_learn.modules, module_name)
130-
dtype = torch.float32
131-
mandatory_param = mandatory_parameters_per_module.get(module_name, {})
132-
133-
module = module_class(device=device, dtype=dtype, **mandatory_param)
134-
for submodule in module.modules():
135-
for name, param in submodule.named_parameters(recurse=False):
136-
assert param.device.type == device, (
137-
f"Submodule parameter '{name}' in {submodule} is on {param.device} but expected {device}"
138-
)
139-
140-
141-
# Optionally, test that all buffers are on the expected device.
142-
@pytest.mark.parametrize(
143-
"device", ["cpu"]
144-
) # if you want to test buffers only on CPU in CI, or parameterize as above
115+
@pytest.mark.parametrize("device", ["cpu"])
145116
@pytest.mark.parametrize("module_name", module_list)
146117
def test_module_buffers_on_device(module_name, device):
147118
"""Verify that all buffers in the module are on the correct device."""
@@ -156,18 +127,7 @@ def test_module_buffers_on_device(module_name, device):
156127
)
157128

158129

159-
@pytest.mark.parametrize(
160-
"device",
161-
[
162-
"cpu",
163-
pytest.param(
164-
"cuda",
165-
marks=pytest.mark.skipif(
166-
not torch.cuda.is_available(), reason="CUDA not available"
167-
),
168-
),
169-
],
170-
)
130+
@pytest.mark.parametrize("device", DEVICES)
171131
@pytest.mark.parametrize(
172132
"dtype",
173133
[torch.float32, torch.float64, torch.complex64, torch.complex128],
@@ -213,18 +173,73 @@ def test_module_dtype(module_name, dtype, device):
213173
x = torch.randn(2, 10, 1000, dtype=dtype)
214174
x = CovLayer(device=device, dtype=dtype)(x)
215175

216-
# checking if torch.linalg.eigh is available
217-
if dtype == torch.float16:
218-
with pytest.raises(RuntimeError):
219-
with torch.no_grad():
220-
out = module(x)
221-
222176
with torch.no_grad():
223177
out = module(x)
224178

225179
assert out.dtype == dtype
226180

227181

182+
@pytest.mark.parametrize("device", DEVICES)
183+
@pytest.mark.parametrize("module_name", module_list)
184+
def test_module_output_device(module_name, device):
185+
"""Run a forward pass and verify the output tensor is on the expected device."""
186+
if module_name == "PositiveDefiniteScalar":
187+
pytest.skip(
188+
"PositiveDefiniteScalar is a scalar parametrization, not a matrix layer."
189+
)
190+
191+
dtype = torch.float32
192+
module_class = getattr(spd_learn.modules, module_name)
193+
mandatory_param = mandatory_parameters_per_module.get(module_name, {})
194+
module = module_class(device=device, dtype=dtype, **mandatory_param)
195+
196+
if module_name in ("CovLayer", "WaveletConv"):
197+
x = torch.randn(2, 10, 1000, device=device, dtype=dtype)
198+
elif module_name == "LogEuclideanResidual":
199+
raw = torch.randn(2, 10, 1000, device=device, dtype=dtype)
200+
cov = CovLayer(device=device, dtype=dtype)
201+
x = cov(raw)
202+
y = cov(torch.randn(2, 10, 1000, device=device, dtype=dtype))
203+
with torch.no_grad():
204+
out = module(x, y)
205+
assert out.device.type == device, (
206+
f"Output is on {out.device} but expected {device}"
207+
)
208+
return
209+
else:
210+
raw = torch.randn(2, 10, 1000, device=device, dtype=dtype)
211+
x = CovLayer(device=device, dtype=dtype)(raw)
212+
213+
with torch.no_grad():
214+
out = module(x)
215+
216+
assert out.device.type == device, (
217+
f"Output is on {out.device} but expected {device}"
218+
)
219+
220+
221+
@pytest.mark.parametrize("device", DEVICES)
222+
@pytest.mark.parametrize("model_name", model_list)
223+
def test_integration_on_device(model_name, device):
224+
"""Create a model, move it to the target device, and verify output shape and device."""
225+
params = {"sfreq": 125} if model_name == "Green" else {}
226+
if model_name == "TensorCSPNet":
227+
x = torch.randn(2, 9, 22, 1000, device=device)
228+
else:
229+
x = torch.randn(2, 22, 1000, device=device)
230+
231+
model = getattr(spd_learn.models, model_name)(n_chans=22, n_outputs=2, **params)
232+
model.to(device)
233+
234+
with torch.no_grad():
235+
out = model(x)
236+
237+
assert out.shape == (2, 2), f"Expected shape (2, 2) but got {out.shape}"
238+
assert out.device.type == device, (
239+
f"Output is on {out.device} but expected {device}"
240+
)
241+
242+
228243
# Batch shapes to test broadcast compatibility
229244
@pytest.mark.parametrize(
230245
"extra_dim",

0 commit comments

Comments
 (0)