
Commit e0a1324

Fix pre-commit issues and reduce epochs for CI
- Add trailing newline to references.bib
- Fix import sorting (ruff I001) in SPDIM example
- Apply ruff formatting to SPDIM example
- Reduce TSMNet and SPDIM(bias) epochs from 200 to 30 for faster documentation build
1 parent 356962f commit e0a1324

2 files changed

Lines changed: 35 additions & 28 deletions
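Note on the I001 fix: ruff's I001 rule ("unsorted-imports") enforces isort-style grouping, with standard-library imports first, third-party imports second, and each group alphabetized with a blank line between groups. A minimal illustration using imports that appear in this example:

# Before (flagged by ruff I001): groups interleaved and unsorted
#   import torch
#   import math
#   import numpy as np
# After `ruff check --fix`: stdlib block, blank line, third-party block
import math
import warnings

import numpy as np
import torch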

docs/source/references.bib

Lines changed: 1 addition & 1 deletion
@@ -1067,4 +1067,4 @@ @inproceedings{
 booktitle={The Thirteenth International Conference on Learning Representations},
 year={2025},
 url={https://openreview.net/forum?id=CoQw1dXtGb}
-}
\ No newline at end of file
+}
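The whole change here is the file's final byte: the closing brace is now followed by a newline, satisfying a check like pre-commit's end-of-file-fixer hook. A minimal sketch of that fix in Python (hypothetical helper, not the hook's actual implementation):

from pathlib import Path

def ensure_trailing_newline(path):
    # Append a final "\n" if the file does not already end with one,
    # mirroring what end-of-file-fixer does to references.bib here.
    data = Path(path).read_bytes()
    if data and not data.endswith(b"\n"):
        Path(path).write_bytes(data + b"\n")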

examples/applied_examples/plot_source_free_domain.py

Lines changed: 34 additions & 27 deletions
@@ -83,6 +83,7 @@
 import numpy as np
 import torch
 
+
 warnings.filterwarnings("ignore")
 
 ######################################################################
@@ -132,9 +133,7 @@ def karcher_mean(X, max_iter=50, return_distances=False):
     if n_samples == 1:
         mean = X[:1]
         if return_distances:
-            return mean, torch.zeros(
-                X.shape[:-2], dtype=X.dtype, device=X.device
-            )
+            return mean, torch.zeros(X.shape[:-2], dtype=X.dtype, device=X.device)
         return mean
 
     w = torch.ones((*X.shape[:-2], 1, 1), dtype=X.dtype, device=X.device)
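For context, karcher_mean computes the Riemannian (Fréchet) mean of a batch of SPD matrices; the hunk above only reflows its single-sample early return. A minimal sketch of the underlying fixed-point iteration under the affine-invariant metric (helper names are illustrative, not the example's actual implementation):

import torch

def _eig_fn(A, fn):
    # Apply a scalar function to the eigenvalues of a symmetric matrix.
    vals, vecs = torch.linalg.eigh(A)
    return (vecs * fn(vals).unsqueeze(-2)) @ vecs.transpose(-2, -1)

def karcher_mean_sketch(X, max_iter=50, tol=1e-8):
    # X: (n, d, d) batch of SPD matrices; returns their Karcher mean.
    mean = X.mean(dim=0)  # arithmetic mean as a warm start
    for _ in range(max_iter):
        sqrt = _eig_fn(mean, torch.sqrt)
        isqrt = _eig_fn(mean, torch.rsqrt)
        # Average of the samples mapped to the tangent space at `mean` ...
        tangent = _eig_fn(isqrt @ X @ isqrt, torch.log).mean(dim=0)
        # ... then move `mean` along that direction via the exponential map.
        mean = sqrt @ _eig_fn(tangent, torch.exp) @ sqrt
        if tangent.norm() < tol:
            break
    return mean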
@@ -195,6 +194,7 @@ def karcher_mean(X, max_iter=50, return_distances=False):
 from moabb.paradigms import MotorImagery
 from sklearn.preprocessing import LabelEncoder
 
+
 dataset = BNCI2015_001()
 paradigm = MotorImagery(
     n_classes=2,
@@ -241,7 +241,10 @@ def karcher_mean(X, max_iter=50, return_distances=False):
 
 # Create braindecode dataset for target domain
 target_ds = create_from_X_y(
-    X_target, y_target, drop_last_window=True, sfreq=sfreq,
+    X_target,
+    y_target,
+    drop_last_window=True,
+    sfreq=sfreq,
 )
 
 print(f"Dataset: {dataset.code}")
@@ -272,20 +275,23 @@ def karcher_mean(X, max_iter=50, return_distances=False):
 # SPDIM protocol: subsample all classes except the last to ratio_level
 rng = np.random.RandomState(42)
 classes = sorted(np.unique(y_target))
-subsample_inds = np.sort(np.concatenate([
-    rng.choice(
-        np.flatnonzero(y_target == c),
-        size=math.ceil(np.sum(y_target == c) * (
-            ratio_level if i < len(classes) - 1 else 1.0
-        )),
-        replace=False,
+subsample_inds = np.sort(
+    np.concatenate(
+        [
+            rng.choice(
+                np.flatnonzero(y_target == c),
+                size=math.ceil(
+                    np.sum(y_target == c)
+                    * (ratio_level if i < len(classes) - 1 else 1.0)
+                ),
+                replace=False,
+            )
+            for i, c in enumerate(classes)
+        ]
     )
-    for i, c in enumerate(classes)
-]))
+)
 
-target_shifted_ds = target_ds.split(
-    by={"shifted": subsample_inds.tolist()}
-)["shifted"]
+target_shifted_ds = target_ds.split(by={"shifted": subsample_inds.tolist()})["shifted"]
 
 # Keep arrays for SPDIM adaptation methods
 X_target_shifted = X_target[subsample_inds]
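The comprehension reformatted above implements the label-shift protocol: every class except the last is subsampled to ratio_level of its trials, while the last class is kept in full. A toy run with assumed numbers (two classes, 50 trials each, ratio_level = 0.3) to show the effect:

import math
import numpy as np

rng = np.random.RandomState(42)
y = np.repeat([0, 1], 50)  # hypothetical balanced labels, 50 per class
ratio_level = 0.3
classes = sorted(np.unique(y))
keep = np.sort(np.concatenate([
    rng.choice(
        np.flatnonzero(y == c),
        size=math.ceil((y == c).sum() * (ratio_level if i < len(classes) - 1 else 1.0)),
        replace=False,
    )
    for i, c in enumerate(classes)
]))
print(np.bincount(y[keep]))  # -> [15 50]: class 0 subsampled, class 1 intact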
@@ -295,10 +301,7 @@ def karcher_mean(X, max_iter=50, return_distances=False):
 print(f" Target samples: {len(target_shifted_ds)}")
 for c in np.unique(y_target_shifted):
     n = (y_target_shifted == c).sum()
-    print(
-        f" Class {le.classes_[c]}: {n} "
-        f"({100 * n / len(y_target_shifted):.0f}%)"
-    )
+    print(f" Class {le.classes_[c]}: {n} ({100 * n / len(y_target_shifted):.0f}%)")
 
 ######################################################################
 # Training TSMNet on Source Domain
@@ -317,6 +320,7 @@ def karcher_mean(X, max_iter=50, return_distances=False):
 
 from spd_learn.models import TSMNet
 
+
 n_chans = X_source.shape[1]
 n_outputs = len(le.classes_)
 
@@ -343,7 +347,7 @@ def karcher_mean(X, max_iter=50, return_distances=False):
     optimizer__weight_decay=1e-4,
     train_split=ValidSplit(0.1, stratified=True, random_state=42),
     batch_size=32,
-    max_epochs=200,
+    max_epochs=30,  # Reduced from 200 for faster documentation build
     callbacks=[
         ("gradient_clip", GradientNormClipping(gradient_clip_value=5.0)),
     ],
@@ -365,6 +369,7 @@ def karcher_mean(X, max_iter=50, return_distances=False):
 
 from sklearn.metrics import balanced_accuracy_score
 
+
 underlying_model = clf.module_
 
 y_pred_source = clf.predict(X_source)
@@ -571,24 +576,24 @@ def forward(self, input):
 with torch.no_grad():
     S_init = matrix_log.apply(target_karcher_mean.squeeze(0).unsqueeze(0))
 
-X_spd_target = extract_spd_features(
-    underlying_model, X_target_shifted, batch_size=32
-)
+X_spd_target = extract_spd_features(underlying_model, X_target_shifted, batch_size=32)
 
 print(f"Log-space parameter S initialized. Shape: {target_karcher_mean.shape}")
 
 adapter = SPDLearnableRecenter(target_karcher_mean.shape[-1])
 adapter.bias = target_karcher_mean.clone()
 
 optimizer_bias = torch.optim.Adam(adapter.parameters(), lr=0.05)
-n_epochs_bias = 200
+n_epochs_bias = 30  # Reduced from 200 for faster documentation build
 losses_bias = []
 best_loss_bias = float("inf")
 best_S = target_karcher_mean.clone().detach()
 for epoch in range(n_epochs_bias):
     optimizer_bias.zero_grad()
     logits = spdim_forward(
-        underlying_model, X_spd_target, adapter,
+        underlying_model,
+        X_spd_target,
+        adapter,
     )
     loss = im_loss(logits, temperature=2.0)
     loss.backward()
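For context on the adaptation objective: im_loss is an information-maximization criterion in the SHOT/SPDIM family, making individual predictions confident (low conditional entropy) while keeping the batch-level class marginal diverse (high marginal entropy). A minimal sketch of such a loss (the exact form computed in the example may differ):

import torch
import torch.nn.functional as F

def im_loss_sketch(logits, temperature=2.0, eps=1e-8):
    probs = F.softmax(logits / temperature, dim=1)
    # Conditional entropy: push each sample toward one confident class.
    cond_ent = -(probs * (probs + eps).log()).sum(dim=1).mean()
    # Marginal entropy: discourage collapsing onto a single class.
    marginal = probs.mean(dim=0)
    marg_ent = -(marginal * (marginal + eps).log()).sum()
    return cond_ent - marg_ent  # minimized during adaptation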
@@ -607,7 +612,9 @@ def forward(self, input):
 with torch.no_grad():
     adapter.bias = best_S
     logits = spdim_forward(
-        underlying_model, X_spd_target, adapter,
+        underlying_model,
+        X_spd_target,
+        adapter,
     )
 y_pred_bias = logits.argmax(dim=1).cpu().numpy()
 
