8282import matplotlib .pyplot as plt
8383import numpy as np
8484import torch
85- import torch .nn .functional as F
86- from braindecode import EEGClassifier
87- from braindecode .datasets import create_from_X_y
88- from moabb .datasets import BNCI2015_001
89- from moabb .paradigms import MotorImagery
90- from sklearn .metrics import balanced_accuracy_score
91- from sklearn .preprocessing import LabelEncoder
92- from skorch .callbacks import GradientNormClipping
93- from skorch .dataset import ValidSplit
94- from torch .nn .utils .parametrize import register_parametrization
95-
96- from spd_learn .functional import (
97- get_epsilon ,
98- matrix_exp ,
99- matrix_inv_sqrt ,
100- matrix_log ,
101- matrix_power ,
102- matrix_sqrt_inv ,
103- )
104- from spd_learn .models import TSMNet
105- from spd_learn .modules import LogEig
106- from spd_learn .modules .manifold import SymmetricPositiveDefinite
10785
10886warnings .filterwarnings ("ignore" )
10987
# These will be included in a future release of ``spd_learn.functional``.
#
96+ from spd_learn .functional import (
97+ get_epsilon ,
98+ matrix_exp ,
99+ matrix_log ,
100+ matrix_power ,
101+ matrix_sqrt_inv ,
102+ )
103+
118104
119105def geodesic_transport_to_identity (X , mean , t ):
120106 r"""Transport SPD matrices along the geodesic toward identity.
@@ -204,6 +190,11 @@ def karcher_mean(X, max_iter=50, return_distances=False):
# - **Target domain**: Session B (adaptation without labels)
#
193+ from braindecode .datasets import create_from_X_y
194+ from moabb .datasets import BNCI2015_001
195+ from moabb .paradigms import MotorImagery
196+ from sklearn .preprocessing import LabelEncoder
197+
207198dataset = BNCI2015_001 ()
208199paradigm = MotorImagery (
209200 n_classes = 2 ,
@@ -320,6 +311,12 @@ def karcher_mean(X, max_iter=50, return_distances=False):
# - **Validation split** (10%) for early stopping
#
314+ from braindecode import EEGClassifier
315+ from skorch .callbacks import GradientNormClipping
316+ from skorch .dataset import ValidSplit
317+
318+ from spd_learn .models import TSMNet
319+
323320n_chans = X_source .shape [1 ]
324321n_outputs = len (le .classes_ )
325322
@@ -366,6 +363,8 @@ def karcher_mean(X, max_iter=50, return_distances=False):
# Evaluate the source-trained model on target domain without adaptation.
#
366+ from sklearn .metrics import balanced_accuracy_score
367+
369368underlying_model = clf .module_
370369
371370y_pred_source = clf .predict (X_source )
@@ -391,6 +390,8 @@ def karcher_mean(X, max_iter=50, return_distances=False):
# rebiasing from the trained BN layer.
#
393+ from spd_learn .modules import LogEig
394+
394395
395396def extract_spd_features (model , X_data , batch_size = 32 ):
396397 """Extract SPD features before batch normalization."""
@@ -491,6 +492,8 @@ def refit_spdbn_karcher(model, X_data, batch_size=32):
# while maintaining class diversity (high marginal entropy).
#
495+ import torch .nn .functional as F
496+
494497
495498def im_loss (logits , temperature = 2.0 ):
496499 """Information Maximization loss (matching SPDIM paper)."""
@@ -521,6 +524,11 @@ def im_loss(logits, temperature=2.0):
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
527+ from torch .nn .utils .parametrize import register_parametrization
528+
529+ from spd_learn .functional import matrix_inv_sqrt
530+ from spd_learn .modules .manifold import SymmetricPositiveDefinite
531+
524532
525533class SPDLearnableRecenter (torch .nn .Module ):
526534 def __init__ (
0 commit comments