3030 ),
3131}
3232
# Devices shared by the device-parametrized tests: always CPU, plus CUDA
# when the local torch build can actually reach a GPU (skipped otherwise).
DEVICES = ["cpu"]
DEVICES.append(
    pytest.param(
        "cuda",
        marks=pytest.mark.skipif(
            not torch.cuda.is_available(), reason="CUDA not available"
        ),
    )
)
42+
3343
3444@pytest .mark .parametrize ("model_name" , model_list )
3545def test_integration (model_name ):
3646 model_class = getattr (spd_learn .models , model_name )
3747
38- params = {}
48+ params = {"sfreq" : 125 } if model_name == "Green" else { }
3949 if model_name == "TensorCSPNet" :
40- # TensorCSPNet requires a different input shape
4150 x = torch .randn (2 , 9 , 22 , 1000 )
42- elif model_name == "Green" :
43- params = {"sfreq" : 125 }
44- x = torch .randn (2 , 22 , 1000 )
4551 else :
4652 x = torch .randn (2 , 22 , 1000 )
4753
@@ -89,20 +95,7 @@ def test_module_expose_device_dtype(module_name):
8995 assert layer is not None
9096
9197
92- # Test that all parameters of the module are on the expected device.
93- @pytest .mark .parametrize (
94- "device" ,
95- [
96- "cpu" ,
97- pytest .param (
98- "cuda" ,
99- marks = pytest .mark .skipif (
100- not torch .cuda .is_available (), reason = "CUDA not available"
101- ),
102- ),
103- # pytest.param("mps", marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available (MAC only)"))
104- ],
105- )
98+ @pytest .mark .parametrize ("device" , DEVICES )
10699@pytest .mark .parametrize ("module_name" , module_list )
107100def test_module_parameters_on_device (module_name , device ):
108101 """Instantiate the module on the given device and verify that each parameter is located on that device."""
@@ -119,29 +112,7 @@ def test_module_parameters_on_device(module_name, device):
119112 )
120113
121114
122- # Optionally, test that all submodules’ parameters are on the expected device.
123- @pytest .mark .parametrize (
124- "device" , ["cpu" ]
125- ) # if you want to test submodules only on CPU in CI, or parameterize as above
126- @pytest .mark .parametrize ("module_name" , module_list )
127- def test_module_submodules_on_device (module_name , device ):
128- """Verify that for each submodule in the module, its parameters are on the correct device."""
129- module_class = getattr (spd_learn .modules , module_name )
130- dtype = torch .float32
131- mandatory_param = mandatory_parameters_per_module .get (module_name , {})
132-
133- module = module_class (device = device , dtype = dtype , ** mandatory_param )
134- for submodule in module .modules ():
135- for name , param in submodule .named_parameters (recurse = False ):
136- assert param .device .type == device , (
137- f"Submodule parameter '{ name } ' in { submodule } is on { param .device } but expected { device } "
138- )
139-
140-
141- # Optionally, test that all buffers are on the expected device.
142- @pytest .mark .parametrize (
143- "device" , ["cpu" ]
144- ) # if you want to test buffers only on CPU in CI, or parameterize as above
115+ @pytest .mark .parametrize ("device" , ["cpu" ])
145116@pytest .mark .parametrize ("module_name" , module_list )
146117def test_module_buffers_on_device (module_name , device ):
147118 """Verify that all buffers in the module are on the correct device."""
@@ -156,18 +127,7 @@ def test_module_buffers_on_device(module_name, device):
156127 )
157128
158129
159- @pytest .mark .parametrize (
160- "device" ,
161- [
162- "cpu" ,
163- pytest .param (
164- "cuda" ,
165- marks = pytest .mark .skipif (
166- not torch .cuda .is_available (), reason = "CUDA not available"
167- ),
168- ),
169- ],
170- )
130+ @pytest .mark .parametrize ("device" , DEVICES )
171131@pytest .mark .parametrize (
172132 "dtype" ,
173133 [torch .float32 , torch .float64 , torch .complex64 , torch .complex128 ],
@@ -213,18 +173,69 @@ def test_module_dtype(module_name, dtype, device):
213173 x = torch .randn (2 , 10 , 1000 , dtype = dtype )
214174 x = CovLayer (device = device , dtype = dtype )(x )
215175
216- # checking if torch.linalg.eigh is available
217- if dtype == torch .float16 :
218- with pytest .raises (RuntimeError ):
219- with torch .no_grad ():
220- out = module (x )
221-
222176 with torch .no_grad ():
223177 out = module (x )
224178
225179 assert out .dtype == dtype
226180
227181
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("module_name", module_list)
def test_module_output_device(module_name, device):
    """Run a forward pass and verify the output tensor is on the expected device."""
    if module_name == "PositiveDefiniteScalar":
        pytest.skip(
            "PositiveDefiniteScalar is a scalar parametrization, not a matrix layer."
        )

    dtype = torch.float32
    mandatory_param = mandatory_parameters_per_module.get(module_name, {})
    module_class = getattr(spd_learn.modules, module_name)
    module = module_class(device=device, dtype=dtype, **mandatory_param)

    def _signal():
        # Fresh (batch, channels, time) raw input created on the target device.
        return torch.randn(2, 10, 1000, device=device, dtype=dtype)

    if module_name in ("CovLayer", "WaveletConv"):
        # These modules consume raw time series directly.
        inputs = (_signal(),)
    elif module_name == "LogEuclideanResidual":
        # The residual layer compares two covariance matrices, so feed it a pair
        # produced by the same CovLayer instance.
        cov = CovLayer(device=device, dtype=dtype)
        inputs = (cov(_signal()), cov(_signal()))
    else:
        # Everything else expects an SPD covariance matrix as input.
        inputs = (CovLayer(device=device, dtype=dtype)(_signal()),)

    with torch.no_grad():
        out = module(*inputs)

    assert out.device.type == device, f"Output is on {out.device} but expected {device}"
217+
218+
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("model_name", model_list)
def test_integration_on_device(model_name, device):
    """Create a model, move it to the target device, and verify output shape and device."""
    # Green requires a sampling frequency at construction time.
    params = {}
    if model_name == "Green":
        params["sfreq"] = 125

    # TensorCSPNet consumes input with an extra leading segment/band dimension;
    # every other model takes plain (batch, channels, time).
    shape = (2, 9, 22, 1000) if model_name == "TensorCSPNet" else (2, 22, 1000)
    x = torch.randn(*shape, device=device)

    model_class = getattr(spd_learn.models, model_name)
    model = model_class(n_chans=22, n_outputs=2, **params).to(device)

    with torch.no_grad():
        out = model(x)

    assert out.shape == (2, 2), f"Expected shape (2, 2) but got {out.shape}"
    assert out.device.type == device, f"Output is on {out.device} but expected {device}"
237+
238+
228239# Batch shapes to test broadcast compatibility
229240@pytest .mark .parametrize (
230241 "extra_dim" ,
0 commit comments