@@ -134,22 +134,24 @@ def __init__(
         tmax = kernel_width_s / 2.0
         tmin = -tmax
         kernel_length = int(kernel_width_s * sfreq)
-        self.register_buffer("tt", torch.linspace(tmin, tmax, kernel_length))
+        self.register_buffer(
+            "tt", torch.linspace(tmin, tmax, kernel_length, device=device)
+        )

         # Convert foi_init to tensor if needed
         if isinstance(foi_init, Tensor):
-            foi_tensor = foi_init.detach().clone()
+            foi_tensor = foi_init.detach().clone().to(device=device)
         else:
-            foi_tensor = torch.tensor(foi_init)
+            foi_tensor = torch.tensor(foi_init, device=device)

         # Generate default fwhm_init if not provided, then convert to tensor
         if fwhm_init is None:
             # Default: FWHM decreases with frequency (negative values in log scale)
             fwhm_tensor = -foi_tensor
         elif isinstance(fwhm_init, Tensor):
-            fwhm_tensor = fwhm_init.detach().clone()
+            fwhm_tensor = fwhm_init.detach().clone().to(device=device)
         else:
-            fwhm_tensor = torch.tensor(fwhm_init)
+            fwhm_tensor = torch.tensor(fwhm_init, device=device)

         self.foi = nn.Parameter(foi_tensor, requires_grad=True)
         self.fwhm = nn.Parameter(fwhm_tensor, requires_grad=True)
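
The pattern adopted throughout this hunk is to allocate buffers and parameters directly on the target device at construction time, rather than building them on CPU and relying on a later move. A minimal sketch of the same idiom follows; the `Demo` class and its arguments are hypothetical, not from this PR, and only the `device=` placement mirrors the diff:

import torch
from torch import nn

class Demo(nn.Module):  # hypothetical minimal module, not from the PR
    def __init__(self, n, device=None):
        super().__init__()
        # One allocation, directly on `device` (no intermediate CPU tensor)
        self.register_buffer("tt", torch.linspace(-0.5, 0.5, n, device=device))
        self.foi = nn.Parameter(torch.tensor([2.0, 8.0], device=device))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
m = Demo(100, device=device)
print(m.tt.device, m.foi.device)  # both already on `device`

Calling `module.to(device)` after construction would also move registered buffers and parameters; creating them on the device up front simply avoids the intermediate CPU allocation and the extra copy.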