Skip to content

Commit 356962f

Browse files
Fix Sphinx-Gallery format and use local imports in SPDIM tutorial
Convert Colab notebook format to proper Sphinx-Gallery style and distribute imports locally near their usage for better readability.
1 parent 26181a8 commit 356962f

1 file changed

Lines changed: 30 additions & 22 deletions

File tree

examples/applied_examples/plot_source_free_domain.py

Lines changed: 30 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -82,28 +82,6 @@
8282
import matplotlib.pyplot as plt
8383
import numpy as np
8484
import torch
85-
import torch.nn.functional as F
86-
from braindecode import EEGClassifier
87-
from braindecode.datasets import create_from_X_y
88-
from moabb.datasets import BNCI2015_001
89-
from moabb.paradigms import MotorImagery
90-
from sklearn.metrics import balanced_accuracy_score
91-
from sklearn.preprocessing import LabelEncoder
92-
from skorch.callbacks import GradientNormClipping
93-
from skorch.dataset import ValidSplit
94-
from torch.nn.utils.parametrize import register_parametrization
95-
96-
from spd_learn.functional import (
97-
get_epsilon,
98-
matrix_exp,
99-
matrix_inv_sqrt,
100-
matrix_log,
101-
matrix_power,
102-
matrix_sqrt_inv,
103-
)
104-
from spd_learn.models import TSMNet
105-
from spd_learn.modules import LogEig
106-
from spd_learn.modules.manifold import SymmetricPositiveDefinite
10785

10886
warnings.filterwarnings("ignore")
10987

@@ -115,6 +93,14 @@
11593
# These will be included in a future release of ``spd_learn.functional``.
11694
#
11795

96+
from spd_learn.functional import (
97+
get_epsilon,
98+
matrix_exp,
99+
matrix_log,
100+
matrix_power,
101+
matrix_sqrt_inv,
102+
)
103+
118104

119105
def geodesic_transport_to_identity(X, mean, t):
120106
r"""Transport SPD matrices along the geodesic toward identity.
@@ -204,6 +190,11 @@ def karcher_mean(X, max_iter=50, return_distances=False):
204190
# - **Target domain**: Session B (adaptation without labels)
205191
#
206192

193+
from braindecode.datasets import create_from_X_y
194+
from moabb.datasets import BNCI2015_001
195+
from moabb.paradigms import MotorImagery
196+
from sklearn.preprocessing import LabelEncoder
197+
207198
dataset = BNCI2015_001()
208199
paradigm = MotorImagery(
209200
n_classes=2,
@@ -320,6 +311,12 @@ def karcher_mean(X, max_iter=50, return_distances=False):
320311
# - **Validation split** (10%) for early stopping
321312
#
322313

314+
from braindecode import EEGClassifier
315+
from skorch.callbacks import GradientNormClipping
316+
from skorch.dataset import ValidSplit
317+
318+
from spd_learn.models import TSMNet
319+
323320
n_chans = X_source.shape[1]
324321
n_outputs = len(le.classes_)
325322

@@ -366,6 +363,8 @@ def karcher_mean(X, max_iter=50, return_distances=False):
366363
# Evaluate the source-trained model on target domain without adaptation.
367364
#
368365

366+
from sklearn.metrics import balanced_accuracy_score
367+
369368
underlying_model = clf.module_
370369

371370
y_pred_source = clf.predict(X_source)
@@ -391,6 +390,8 @@ def karcher_mean(X, max_iter=50, return_distances=False):
391390
# rebiasing from the trained BN layer.
392391
#
393392

393+
from spd_learn.modules import LogEig
394+
394395

395396
def extract_spd_features(model, X_data, batch_size=32):
396397
"""Extract SPD features before batch normalization."""
@@ -491,6 +492,8 @@ def refit_spdbn_karcher(model, X_data, batch_size=32):
491492
# while maintaining class diversity (high marginal entropy).
492493
#
493494

495+
import torch.nn.functional as F
496+
494497

495498
def im_loss(logits, temperature=2.0):
496499
"""Information Maximization loss (matching SPDIM paper)."""
@@ -521,6 +524,11 @@ def im_loss(logits, temperature=2.0):
521524
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
522525
#
523526

527+
from torch.nn.utils.parametrize import register_parametrization
528+
529+
from spd_learn.functional import matrix_inv_sqrt
530+
from spd_learn.modules.manifold import SymmetricPositiveDefinite
531+
524532

525533
class SPDLearnableRecenter(torch.nn.Module):
526534
def __init__(

0 commit comments

Comments (0)