Skip to content

Commit c78c8fe

Browse files
committed
EXA use himalaya.backend.set_backend(..., on_error='warn')
1 parent 89c6b0a commit c78c8fe

4 files changed

Lines changed: 21 additions & 12 deletions

File tree

tutorials/movies_3T/02_plot_wordnet_model.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,12 +134,19 @@
134134
# target case, ``GridSearchCV`` can only optimize e.g. the mean score over
135135
# targets. Here, we want to find a different optimal hyperparameter per
136136
# target/voxel, so we use ``himalaya``'s ``KernelRidgeCV`` instead.
137-
# Moreover, ``himalaya`` implements different computational backends; we will
138-
# use the "torch_cuda" backend to use fast computations on GPU.
139137
from himalaya.kernel_ridge import KernelRidgeCV
140138

139+
###############################################################################
140+
# Moreover, ``himalaya`` implements different computational backends, including
141+
# GPU backends. The available GPU backends are "torch_cuda" and "cupy". (These
142+
# backends are only available if you installed the corresponding package with
143+
# CUDA enabled. Check the pytorch/cupy documentation for install instructions.)
144+
#
145+
# Here we use the "torch_cuda" backend, but if the import fails we continue
146+
# with the default "numpy" backend. The "numpy" backend is expected to be
147+
# slower since it only uses the CPU.
141148
from himalaya.backend import set_backend
142-
backend = set_backend("torch_cuda")
149+
backend = set_backend("torch_cuda", on_error="warn")
143150

144151
###############################################################################
145152
# The scale of the regularization hyperparameter ``alpha`` is unknown, so we

tutorials/movies_3T/03_plot_motion_energy_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@
106106
from voxelwise_tutorials.delayer import Delayer
107107
from himalaya.kernel_ridge import KernelRidgeCV
108108
from himalaya.backend import set_backend
109-
backend = set_backend("torch_cuda")
109+
backend = set_backend("torch_cuda", on_error="warn")
110110

111111
alphas = np.logspace(1, 20, 20)
112112

tutorials/movies_3T/04_plot_banded_ridge_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@
101101
from sklearn.preprocessing import StandardScaler
102102
from voxelwise_tutorials.delayer import Delayer
103103
from himalaya.backend import set_backend
104-
backend = set_backend("torch_cuda")
104+
backend = set_backend("torch_cuda", on_error="warn")
105105

106106
###############################################################################
107107
# To fit the banded ridge model, we use ``himalaya``'s

tutorials/movies_4T/02_plot_ridge_model.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -150,15 +150,17 @@
150150
from voxelwise_tutorials.delayer import Delayer
151151

152152
###############################################################################
153-
# We set himalaya's backend to "torch_cuda" to fit the model using GPU.
154-
# The available backends are:
153+
# The package``himalaya`` implements different computational backends,
154+
# including GPU backends. The available GPU backends are "torch_cuda" and
155+
# "cupy". (These backends are only available if you installed the corresponding
156+
# package with CUDA enabled. Check the pytorch/cupy documentation for install
157+
# instructions.)
155158
#
156-
# - "numpy" (CPU) (default)
157-
# - "torch" (CPU)
158-
# - "torch_cuda" (GPU)
159-
# - "cupy" (GPU)
159+
# Here we use the "torch_cuda" backend, but if the import fails we continue
160+
# with the default "numpy" backend. The "numpy" backend is expected to be
161+
# slower since it only uses the CPU.
160162
from himalaya.backend import set_backend
161-
backend = set_backend("torch_cuda")
163+
backend = set_backend("torch_cuda", on_error="warn")
162164

163165
###############################################################################
164166
# The scale of the regularization hyperparameter alpha is unknown, so we use

0 commit comments

Comments
 (0)