Skip to content

Commit 07bdcae

Browse files
committed
- Fix CUDA device mismatch in WaveletConv and ledoit_wolf
- WaveletConv.forward hardcoded self.device (defaults to cpu), ignoring the input tensor's device. ledoit_wolf didn't coerce shrink_mat to the input device/dtype. Both caused RuntimeError when running on gpu.
1 parent db1f8b6 commit 07bdcae

2 files changed

Lines changed: 2 additions & 2 deletions

File tree

spd_learn/functional/regularize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def ledoit_wolf(
7979
alpha = torch.sigmoid(shrinkage)
8080

8181
return (1 - alpha)[..., None, None] * covariances + alpha[..., None, None] * (
82-
mu[..., None] * shrink_mat
82+
mu[..., None] * shrink_mat.to(device=covariances.device, dtype=covariances.dtype)
8383
)
8484

8585

spd_learn/modules/wavelet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,4 +203,4 @@ def forward(self, X: Tensor) -> Tensor:
203203
n_batch, n_freqs, n_sensors, n_epochs, n_times = X_conv.shape
204204
X_conv = X_conv.view(n_batch, n_freqs, n_sensors, n_epochs * n_times)
205205

206-
return X_conv.to(device=self.device, dtype=self.dtype)
206+
return X_conv.to(device=X.device, dtype=self.dtype)

0 commit comments

Comments
 (0)