Skip to content

Commit 24e0209

Browse files
Merge pull request #16 from timnaher/fix/cuda-device-mismatch
Fix CUDA device mismatch in WaveletConv and ledoit_wolf
2 parents: db1f8b6 + e75a9d2 · commit 24e0209

3 files changed

Lines changed: 81 additions & 67 deletions

File tree

spd_learn/functional/regularize.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,8 @@ def ledoit_wolf(
7979
alpha = torch.sigmoid(shrinkage)
8080

8181
return (1 - alpha)[..., None, None] * covariances + alpha[..., None, None] * (
82-
mu[..., None] * shrink_mat
82+
mu[..., None]
83+
* shrink_mat.to(device=covariances.device, dtype=covariances.dtype)
8384
)
8485

8586

spd_learn/modules/wavelet.py

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -134,22 +134,24 @@ def __init__(
134134
tmax = kernel_width_s / 2.0
135135
tmin = -tmax
136136
kernel_length = int(kernel_width_s * sfreq)
137-
self.register_buffer("tt", torch.linspace(tmin, tmax, kernel_length))
137+
self.register_buffer(
138+
"tt", torch.linspace(tmin, tmax, kernel_length, device=device)
139+
)
138140

139141
# Convert foi_init to tensor if needed
140142
if isinstance(foi_init, Tensor):
141-
foi_tensor = foi_init.detach().clone()
143+
foi_tensor = foi_init.detach().clone().to(device=device)
142144
else:
143-
foi_tensor = torch.tensor(foi_init)
145+
foi_tensor = torch.tensor(foi_init, device=device)
144146

145147
# Generate default fwhm_init if not provided, then convert to tensor
146148
if fwhm_init is None:
147149
# Default: FWHM decreases with frequency (negative values in log scale)
148150
fwhm_tensor = -foi_tensor
149151
elif isinstance(fwhm_init, Tensor):
150-
fwhm_tensor = fwhm_init.detach().clone()
152+
fwhm_tensor = fwhm_init.detach().clone().to(device=device)
151153
else:
152-
fwhm_tensor = torch.tensor(fwhm_init)
154+
fwhm_tensor = torch.tensor(fwhm_init, device=device)
153155

154156
self.foi = nn.Parameter(foi_tensor, requires_grad=True)
155157
self.fwhm = nn.Parameter(fwhm_tensor, requires_grad=True)
@@ -203,4 +205,4 @@ def forward(self, X: Tensor) -> Tensor:
203205
n_batch, n_freqs, n_sensors, n_epochs, n_times = X_conv.shape
204206
X_conv = X_conv.view(n_batch, n_freqs, n_sensors, n_epochs * n_times)
205207

206-
return X_conv.to(device=self.device, dtype=self.dtype)
208+
return X_conv.to(device=X.device, dtype=self.dtype)

tests/test_integration.py

Lines changed: 71 additions & 60 deletions
Original file line number | Diff line number | Diff line change
@@ -30,18 +30,24 @@
3030
),
3131
}
3232

33+
DEVICES = [
34+
"cpu",
35+
pytest.param(
36+
"cuda",
37+
marks=pytest.mark.skipif(
38+
not torch.cuda.is_available(), reason="CUDA not available"
39+
),
40+
),
41+
]
42+
3343

3444
@pytest.mark.parametrize("model_name", model_list)
3545
def test_integration(model_name):
3646
model_class = getattr(spd_learn.models, model_name)
3747

38-
params = {}
48+
params = {"sfreq": 125} if model_name == "Green" else {}
3949
if model_name == "TensorCSPNet":
40-
# TensorCSPNet requires a different input shape
4150
x = torch.randn(2, 9, 22, 1000)
42-
elif model_name == "Green":
43-
params = {"sfreq": 125}
44-
x = torch.randn(2, 22, 1000)
4551
else:
4652
x = torch.randn(2, 22, 1000)
4753

@@ -89,20 +95,7 @@ def test_module_expose_device_dtype(module_name):
8995
assert layer is not None
9096

9197

92-
# Test that all parameters of the module are on the expected device.
93-
@pytest.mark.parametrize(
94-
"device",
95-
[
96-
"cpu",
97-
pytest.param(
98-
"cuda",
99-
marks=pytest.mark.skipif(
100-
not torch.cuda.is_available(), reason="CUDA not available"
101-
),
102-
),
103-
# pytest.param("mps", marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available (MAC only)"))
104-
],
105-
)
98+
@pytest.mark.parametrize("device", DEVICES)
10699
@pytest.mark.parametrize("module_name", module_list)
107100
def test_module_parameters_on_device(module_name, device):
108101
"""Instantiate the module on the given device and verify that each parameter is located on that device."""
@@ -119,29 +112,7 @@ def test_module_parameters_on_device(module_name, device):
119112
)
120113

121114

122-
# Optionally, test that all submodules’ parameters are on the expected device.
123-
@pytest.mark.parametrize(
124-
"device", ["cpu"]
125-
) # if you want to test submodules only on CPU in CI, or parameterize as above
126-
@pytest.mark.parametrize("module_name", module_list)
127-
def test_module_submodules_on_device(module_name, device):
128-
"""Verify that for each submodule in the module, its parameters are on the correct device."""
129-
module_class = getattr(spd_learn.modules, module_name)
130-
dtype = torch.float32
131-
mandatory_param = mandatory_parameters_per_module.get(module_name, {})
132-
133-
module = module_class(device=device, dtype=dtype, **mandatory_param)
134-
for submodule in module.modules():
135-
for name, param in submodule.named_parameters(recurse=False):
136-
assert param.device.type == device, (
137-
f"Submodule parameter '{name}' in {submodule} is on {param.device} but expected {device}"
138-
)
139-
140-
141-
# Optionally, test that all buffers are on the expected device.
142-
@pytest.mark.parametrize(
143-
"device", ["cpu"]
144-
) # if you want to test buffers only on CPU in CI, or parameterize as above
115+
@pytest.mark.parametrize("device", ["cpu"])
145116
@pytest.mark.parametrize("module_name", module_list)
146117
def test_module_buffers_on_device(module_name, device):
147118
"""Verify that all buffers in the module are on the correct device."""
@@ -156,18 +127,7 @@ def test_module_buffers_on_device(module_name, device):
156127
)
157128

158129

159-
@pytest.mark.parametrize(
160-
"device",
161-
[
162-
"cpu",
163-
pytest.param(
164-
"cuda",
165-
marks=pytest.mark.skipif(
166-
not torch.cuda.is_available(), reason="CUDA not available"
167-
),
168-
),
169-
],
170-
)
130+
@pytest.mark.parametrize("device", DEVICES)
171131
@pytest.mark.parametrize(
172132
"dtype",
173133
[torch.float32, torch.float64, torch.complex64, torch.complex128],
@@ -213,18 +173,69 @@ def test_module_dtype(module_name, dtype, device):
213173
x = torch.randn(2, 10, 1000, dtype=dtype)
214174
x = CovLayer(device=device, dtype=dtype)(x)
215175

216-
# checking if torch.linalg.eigh is available
217-
if dtype == torch.float16:
218-
with pytest.raises(RuntimeError):
219-
with torch.no_grad():
220-
out = module(x)
221-
222176
with torch.no_grad():
223177
out = module(x)
224178

225179
assert out.dtype == dtype
226180

227181

182+
@pytest.mark.parametrize("device", DEVICES)
183+
@pytest.mark.parametrize("module_name", module_list)
184+
def test_module_output_device(module_name, device):
185+
"""Run a forward pass and verify the output tensor is on the expected device."""
186+
if module_name == "PositiveDefiniteScalar":
187+
pytest.skip(
188+
"PositiveDefiniteScalar is a scalar parametrization, not a matrix layer."
189+
)
190+
191+
dtype = torch.float32
192+
module_class = getattr(spd_learn.modules, module_name)
193+
mandatory_param = mandatory_parameters_per_module.get(module_name, {})
194+
module = module_class(device=device, dtype=dtype, **mandatory_param)
195+
196+
if module_name in ("CovLayer", "WaveletConv"):
197+
x = torch.randn(2, 10, 1000, device=device, dtype=dtype)
198+
elif module_name == "LogEuclideanResidual":
199+
raw = torch.randn(2, 10, 1000, device=device, dtype=dtype)
200+
cov = CovLayer(device=device, dtype=dtype)
201+
x = cov(raw)
202+
y = cov(torch.randn(2, 10, 1000, device=device, dtype=dtype))
203+
with torch.no_grad():
204+
out = module(x, y)
205+
assert out.device.type == device, (
206+
f"Output is on {out.device} but expected {device}"
207+
)
208+
return
209+
else:
210+
raw = torch.randn(2, 10, 1000, device=device, dtype=dtype)
211+
x = CovLayer(device=device, dtype=dtype)(raw)
212+
213+
with torch.no_grad():
214+
out = module(x)
215+
216+
assert out.device.type == device, f"Output is on {out.device} but expected {device}"
217+
218+
219+
@pytest.mark.parametrize("device", DEVICES)
220+
@pytest.mark.parametrize("model_name", model_list)
221+
def test_integration_on_device(model_name, device):
222+
"""Create a model, move it to the target device, and verify output shape and device."""
223+
params = {"sfreq": 125} if model_name == "Green" else {}
224+
if model_name == "TensorCSPNet":
225+
x = torch.randn(2, 9, 22, 1000, device=device)
226+
else:
227+
x = torch.randn(2, 22, 1000, device=device)
228+
229+
model = getattr(spd_learn.models, model_name)(n_chans=22, n_outputs=2, **params)
230+
model.to(device)
231+
232+
with torch.no_grad():
233+
out = model(x)
234+
235+
assert out.shape == (2, 2), f"Expected shape (2, 2) but got {out.shape}"
236+
assert out.device.type == device, f"Output is on {out.device} but expected {device}"
237+
238+
228239
# Batch shapes to test broadcast compatibility
229240
@pytest.mark.parametrize(
230241
"extra_dim",

0 commit comments

Comments
 (0)