Skip to content

Commit 5f718a9

Browse files
Fix WaveletConv device mismatch for foi, fwhm, and tt tensors
Pass device parameter when creating foi, fwhm, and tt tensors so they are placed on the correct device at init time instead of defaulting to CPU.
1 parent 958be46 commit 5f718a9

1 file changed

Lines changed: 7 additions & 5 deletions

File tree

spd_learn/modules/wavelet.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,22 +134,24 @@ def __init__(
134134
tmax = kernel_width_s / 2.0
135135
tmin = -tmax
136136
kernel_length = int(kernel_width_s * sfreq)
137-
self.register_buffer("tt", torch.linspace(tmin, tmax, kernel_length))
137+
self.register_buffer(
138+
"tt", torch.linspace(tmin, tmax, kernel_length, device=device)
139+
)
138140

139141
# Convert foi_init to tensor if needed
140142
if isinstance(foi_init, Tensor):
141-
foi_tensor = foi_init.detach().clone()
143+
foi_tensor = foi_init.detach().clone().to(device=device)
142144
else:
143-
foi_tensor = torch.tensor(foi_init)
145+
foi_tensor = torch.tensor(foi_init, device=device)
144146

145147
# Generate default fwhm_init if not provided, then convert to tensor
146148
if fwhm_init is None:
147149
# Default: FWHM decreases with frequency (negative values in log scale)
148150
fwhm_tensor = -foi_tensor
149151
elif isinstance(fwhm_init, Tensor):
150-
fwhm_tensor = fwhm_init.detach().clone()
152+
fwhm_tensor = fwhm_init.detach().clone().to(device=device)
151153
else:
152-
fwhm_tensor = torch.tensor(fwhm_init)
154+
fwhm_tensor = torch.tensor(fwhm_init, device=device)
153155

154156
self.foi = nn.Parameter(foi_tensor, requires_grad=True)
155157
self.fwhm = nn.Parameter(fwhm_tensor, requires_grad=True)

0 commit comments

Comments
 (0)