 import numpy as np
 import torch
 
+
 warnings.filterwarnings("ignore")
 
 ######################################################################
@@ -132,9 +133,7 @@ def karcher_mean(X, max_iter=50, return_distances=False):
     if n_samples == 1:
         mean = X[:1]
         if return_distances:
-            return mean, torch.zeros(
-                X.shape[:-2], dtype=X.dtype, device=X.device
-            )
+            return mean, torch.zeros(X.shape[:-2], dtype=X.dtype, device=X.device)
         return mean
 
     w = torch.ones((*X.shape[:-2], 1, 1), dtype=X.dtype, device=X.device)
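For readers skimming the diff: `karcher_mean` computes a Fréchet (Karcher) mean on the SPD manifold. A minimal sketch of the standard fixed-point iteration under the affine-invariant metric, using eigendecomposition-based matrix functions; helper names here are illustrative, not the tutorial's implementation:

import torch


def _sym_funcm(A, fn):
    # Apply a scalar function to the eigenvalues of symmetric matrices.
    eigvals, eigvecs = torch.linalg.eigh(A)
    return eigvecs @ torch.diag_embed(fn(eigvals)) @ eigvecs.transpose(-2, -1)


def karcher_mean_sketch(X, max_iter=50, tol=1e-8):
    # X: (n_samples, n, n) stack of SPD matrices.
    mean = X.mean(dim=0)  # arithmetic mean as a warm start
    for _ in range(max_iter):
        inv_sqrt = _sym_funcm(mean, torch.rsqrt)
        sqrt = _sym_funcm(mean, torch.sqrt)
        # Average the log-maps of all samples at the current estimate.
        tangent = _sym_funcm(inv_sqrt @ X @ inv_sqrt, torch.log).mean(dim=0)
        mean = sqrt @ _sym_funcm(tangent, torch.exp) @ sqrt
        if tangent.norm() < tol:  # converged when the mean tangent vanishes
            break
    return mean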
@@ -195,6 +194,7 @@ def karcher_mean(X, max_iter=50, return_distances=False):
 from moabb.paradigms import MotorImagery
 from sklearn.preprocessing import LabelEncoder
 
+
 dataset = BNCI2015_001()
 paradigm = MotorImagery(
     n_classes=2,
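For readers unfamiliar with MOABB: pairing a paradigm with a dataset yields epoched trial arrays via `get_data`. A minimal hedged usage sketch; the subject list is illustrative, and the tutorial's own data extraction happens outside this hunk:

# Hedged usage sketch: MOABB paradigms expose get_data(), which returns
# epoched trials, integer-codable labels, and per-trial metadata.
X, y, metadata = paradigm.get_data(dataset=dataset, subjects=[1])
print(X.shape)  # (n_trials, n_channels, n_times)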
@@ -241,7 +241,10 @@ def karcher_mean(X, max_iter=50, return_distances=False):
 
 # Create braindecode dataset for target domain
 target_ds = create_from_X_y(
-    X_target, y_target, drop_last_window=True, sfreq=sfreq,
+    X_target,
+    y_target,
+    drop_last_window=True,
+    sfreq=sfreq,
 )
 
 print(f"Dataset: {dataset.code}")
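For context, braindecode's `create_from_X_y` wraps numpy trials into a windowed dataset that skorch-style loops can consume. A minimal hedged sketch with illustrative shapes:

# Hedged sketch: create_from_X_y expects trials shaped
# (n_trials, n_channels, n_times) plus labels, and returns a
# braindecode windows dataset. All values below are illustrative.
import numpy as np
from braindecode.datasets import create_from_X_y

X_demo = np.random.randn(10, 22, 512).astype("float32")
y_demo = np.random.randint(0, 2, size=10)
demo_ds = create_from_X_y(X_demo, y_demo, drop_last_window=True, sfreq=256.0)
print(len(demo_ds))  # one window per trial when windows span whole trials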
@@ -272,20 +275,23 @@ def karcher_mean(X, max_iter=50, return_distances=False):
 # SPDIM protocol: subsample all classes except the last to ratio_level
 rng = np.random.RandomState(42)
 classes = sorted(np.unique(y_target))
-subsample_inds = np.sort(np.concatenate([
-    rng.choice(
-        np.flatnonzero(y_target == c),
-        size=math.ceil(np.sum(y_target == c) * (
-            ratio_level if i < len(classes) - 1 else 1.0
-        )),
-        replace=False,
+subsample_inds = np.sort(
+    np.concatenate(
+        [
+            rng.choice(
+                np.flatnonzero(y_target == c),
+                size=math.ceil(
+                    np.sum(y_target == c)
+                    * (ratio_level if i < len(classes) - 1 else 1.0)
+                ),
+                replace=False,
+            )
+            for i, c in enumerate(classes)
+        ]
     )
-    for i, c in enumerate(classes)
-]))
+)
 
-target_shifted_ds = target_ds.split(
-    by={"shifted": subsample_inds.tolist()}
-)["shifted"]
+target_shifted_ds = target_ds.split(by={"shifted": subsample_inds.tolist()})["shifted"]
 
 # Keep arrays for SPDIM adaptation methods
 X_target_shifted = X_target[subsample_inds]
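The comprehension above is easier to read on toy labels: every class except the last keeps `ceil(ratio_level * n_c)` trials, while the last class keeps everything. A hedged worked example with illustrative numbers:

# Hedged worked example of the label-shift protocol above, on toy labels.
# With 100 trials per class and ratio_level = 0.2, class 0 keeps
# ceil(100 * 0.2) = 20 trials while the last class keeps all 100,
# producing a roughly 1:5 class imbalance in the shifted target set.
import math
import numpy as np

y_toy = np.array([0] * 100 + [1] * 100)
ratio = 0.2
kept = [math.ceil((y_toy == c).sum() * (ratio if i < 1 else 1.0))
        for i, c in enumerate(sorted(np.unique(y_toy)))]
print(kept)  # [20, 100]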
@@ -295,10 +301,7 @@ def karcher_mean(X, max_iter=50, return_distances=False):
 print(f"  Target samples: {len(target_shifted_ds)}")
 for c in np.unique(y_target_shifted):
     n = (y_target_shifted == c).sum()
-    print(
-        f"    Class {le.classes_[c]}: {n} "
-        f"({100 * n / len(y_target_shifted):.0f}%)"
-    )
+    print(f"    Class {le.classes_[c]}: {n} ({100 * n / len(y_target_shifted):.0f}%)")
 
 ######################################################################
 # Training TSMNet on Source Domain
@@ -317,6 +320,7 @@ def karcher_mean(X, max_iter=50, return_distances=False):
 
 from spd_learn.models import TSMNet
 
+
 n_chans = X_source.shape[1]
 n_outputs = len(le.classes_)
 
@@ -343,7 +347,7 @@ def karcher_mean(X, max_iter=50, return_distances=False):
     optimizer__weight_decay=1e-4,
     train_split=ValidSplit(0.1, stratified=True, random_state=42),
     batch_size=32,
-    max_epochs=200,
+    max_epochs=30,  # Reduced from 200 for faster documentation build
     callbacks=[
         ("gradient_clip", GradientNormClipping(gradient_clip_value=5.0)),
     ],
@@ -365,6 +369,7 @@ def karcher_mean(X, max_iter=50, return_distances=False):
 
 from sklearn.metrics import balanced_accuracy_score
 
+
 underlying_model = clf.module_
 
 y_pred_source = clf.predict(X_source)
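`balanced_accuracy_score` averages per-class recalls, which makes it the right metric under the simulated label shift; plain accuracy would reward always predicting the majority class. A quick hedged check with toy values:

# Hedged check: balanced accuracy equals the mean of per-class recalls,
# so a majority-class-only classifier scores 0.5 on two classes no
# matter how imbalanced the labels are. Toy values for illustration.
import numpy as np
from sklearn.metrics import balanced_accuracy_score

y_true = np.array([0] * 20 + [1] * 100)
y_maj = np.ones_like(y_true)  # always predict the majority class
print(balanced_accuracy_score(y_true, y_maj))  # 0.5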
@@ -571,24 +576,24 @@ def forward(self, input):
 with torch.no_grad():
     S_init = matrix_log.apply(target_karcher_mean.squeeze(0).unsqueeze(0))
 
-X_spd_target = extract_spd_features(
-    underlying_model, X_target_shifted, batch_size=32
-)
+X_spd_target = extract_spd_features(underlying_model, X_target_shifted, batch_size=32)
 
 print(f"Log-space parameter S initialized. Shape: {target_karcher_mean.shape}")
 
 adapter = SPDLearnableRecenter(target_karcher_mean.shape[-1])
 adapter.bias = target_karcher_mean.clone()
 
 optimizer_bias = torch.optim.Adam(adapter.parameters(), lr=0.05)
-n_epochs_bias = 200
+n_epochs_bias = 30  # Reduced from 200 for faster documentation build
 losses_bias = []
 best_loss_bias = float("inf")
 best_S = target_karcher_mean.clone().detach()
 for epoch in range(n_epochs_bias):
     optimizer_bias.zero_grad()
     logits = spdim_forward(
-        underlying_model, X_spd_target, adapter,
+        underlying_model,
+        X_spd_target,
+        adapter,
     )
     loss = im_loss(logits, temperature=2.0)
     loss.backward()
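The `im_loss` minimized in this loop is an information-maximization objective common in test-time adaptation: push each sample toward a confident prediction while keeping the marginal class distribution diverse. A hedged sketch of one standard formulation; the tutorial's actual `im_loss`, defined outside this hunk, may differ in temperature handling or sign conventions:

# Hedged sketch of an information-maximization (IM) loss: low conditional
# entropy (confident per-sample predictions) plus high marginal entropy
# (diverse predictions overall). Minimizing the difference achieves both.
import torch
import torch.nn.functional as F


def im_loss_sketch(logits, temperature=2.0):
    probs = F.softmax(logits / temperature, dim=1)
    cond_entropy = -(probs * probs.clamp_min(1e-8).log()).sum(dim=1).mean()
    marginal = probs.mean(dim=0)
    marginal_entropy = -(marginal * marginal.clamp_min(1e-8).log()).sum()
    return cond_entropy - marginal_entropy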
@@ -607,7 +612,9 @@ def forward(self, input):
 with torch.no_grad():
     adapter.bias = best_S
     logits = spdim_forward(
-        underlying_model, X_spd_target, adapter,
+        underlying_model,
+        X_spd_target,
+        adapter,
     )
     y_pred_bias = logits.argmax(dim=1).cpu().numpy()
 
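After restoring the best log-space bias, the adapted predictions can be scored the same way the frozen source model was scored earlier. A hedged sketch reusing the tutorial's variable names:

# Hedged evaluation sketch: compare the adapted predictions against the
# shifted target labels with the same balanced-accuracy metric used for
# the source model above.
from sklearn.metrics import balanced_accuracy_score

bacc_bias = balanced_accuracy_score(y_target_shifted, y_pred_bias)
print(f"SPDIM bias adaptation balanced accuracy: {bacc_bias:.3f}")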