From b9932ed5b6b955f1b74b9e1e39130b7e3daa418e Mon Sep 17 00:00:00 2001 From: Anton Date: Tue, 16 Jun 2026 08:43:09 +0200 Subject: [PATCH 1/5] Fixes a problem with windows absolute download path being corrupted by _sanitize_path --- moabb/datasets/download.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/moabb/datasets/download.py b/moabb/datasets/download.py index 95826e91b..d8aa64d17 100644 --- a/moabb/datasets/download.py +++ b/moabb/datasets/download.py @@ -41,8 +41,16 @@ def _set_user_agent(downloader): def _sanitize_path(path: Path) -> Path: + path = Path(path) table = {ord(c): "-" for c in ':*?"<>|'} - return Path(str(path).translate(table)) + + if path.anchor: + return Path( + path.anchor, + *(part.translate(table) for part in path.parts[1:]), + ) + + return Path(*(part.translate(table) for part in path.parts)) def _normalize_destination(url: str, root: Path) -> Path: From 9d88069448f86880926dbd18967814ba88ae2ab1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Jun 2026 06:49:15 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks --- moabb/datasets/download.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/moabb/datasets/download.py b/moabb/datasets/download.py index d8aa64d17..0423c3ca9 100644 --- a/moabb/datasets/download.py +++ b/moabb/datasets/download.py @@ -45,10 +45,7 @@ def _sanitize_path(path: Path) -> Path: table = {ord(c): "-" for c in ':*?"<>|'} if path.anchor: - return Path( - path.anchor, - *(part.translate(table) for part in path.parts[1:]), - ) + return Path(path.anchor, *(part.translate(table) for part in path.parts[1:])) return Path(*(part.translate(table) for part in path.parts)) From 551e911dd95f9dc46f6ac6c3a1ab30a615067694 Mon Sep 17 00:00:00 2001 From: Anton Date: Tue, 16 Jun 2026 08:57:42 +0200 Subject: [PATCH 3/5] Updated for - PR 1079. --- docs/source/whats_new.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/whats_new.rst b/docs/source/whats_new.rst index 646f99ea8..3c98de47f 100644 --- a/docs/source/whats_new.rst +++ b/docs/source/whats_new.rst @@ -62,7 +62,7 @@ Bugs - Fix EEG layout corruption in :class:`moabb.datasets.BNCI2020_002`: the F-contiguous ``bciexp.data`` was reshaped in default C-order, producing a trial-fastest interleaved layout that disagreed with the per-trial stim markers and made every epoch sample the wrong trial. The reshape now transposes to trial-major before flattening (by `Bruno Aristimunha`_). - Fix ``stim_trial`` content in :class:`moabb.datasets.MartinezCagigal2023Checker` and :class:`moabb.datasets.MartinezCagigal2023Pary`: the channel was carrying the per-recording trial index instead of the attended command id, breaking multiclass classification across recordings. The marker is now the command id (resolved via the new :func:`moabb.datasets.utils.resolve_cvep_command_ids` helper), and the ``_trial_meta`` annotation extras gain a ``command_id`` key alongside ``trial_id`` (by `Bruno Aristimunha`_). - Cache Figshare's file listing in :func:`moabb.datasets.download.fs_get_file_list` (process-level ``lru_cache``) and persist it on disk next to the data for MAMEM (:class:`moabb.datasets.MAMEM1`/``MAMEM2``/``MAMEM3``). Once a dataset has been downloaded, subsequent calls never contact Figshare; pass ``force_update=True`` to bypass both layers (by `Bruno Aristimunha`_). - +- Fix Windows download path sanitization that changed absolute paths like ``C:\data`` into relative ``C-\data`` paths (:gh:`1079` by `Anton Andreev`_). Code health ~~~~~~~~~~~ - None yet. From 6d000f75bfb3dfcb261999b3775d89615070bdeb Mon Sep 17 00:00:00 2001 From: Anton Date: Fri, 19 Jun 2026 10:49:27 +0200 Subject: [PATCH 4/5] The new CrossSubjectTargetAwareEvaluation that supports 6 modes of evaluation. And an example on how to use 4 of them using RPA (Riemannian Procrustes Analysis). --- ...cross_subject_transfer_learning_example.py | 631 ++++++++++++ .../cross_subject_target_aware_evaluation.py | 895 ++++++++++++++++++ 2 files changed, 1526 insertions(+) create mode 100644 examples/how_to_benchmark/cross_subject_transfer_learning_example.py create mode 100644 moabb/evaluations/cross_subject_target_aware_evaluation.py diff --git a/examples/how_to_benchmark/cross_subject_transfer_learning_example.py b/examples/how_to_benchmark/cross_subject_transfer_learning_example.py new file mode 100644 index 000000000..b1643b1af --- /dev/null +++ b/examples/how_to_benchmark/cross_subject_transfer_learning_example.py @@ -0,0 +1,631 @@ +""" +Example of the transfer-learning-oriented +CrossSubjectTargetAwareEvaluation. + +This example compares a standard TS + LR pipeline with an RPA + TS + LR +pipeline. Riemannian Procrustes Alignment is used here as a simple example of +a target-aware transfer-learning method: it can use source-subject structure and, +when allowed by the evaluation mode, unlabeled target covariance data for +alignment. + +The script can demonstrate 4 of the 6 available modes by changing `cs_mode`: + +* HOS_SOURCE_ONLY_BLOCKWISE: + No target data is used during adaptation. RPA aligns only the source + subjects, so this should usually be weaker than target-adaptive modes. + +* HOS_UNLABELED_20P: + The first 20% of the held-out target subject is provided without labels + for target-domain alignment. The remaining 80% is evaluated. + +* HOS_UNLABELED_50P: + The first 50% of the held-out target subject is provided without labels + for target-domain alignment. The remaining 50% is evaluated. + +* HOS_LABELED_20P: + The first 20% of the held-out target subject is provided as labeled + calibration data. In this example, RPA ignores the labels and uses only + the covariance distribution for alignment, so HOS_LABELED_20P and + HOS_UNLABELED_20P should give very similar or identical results for the + RPA step. + +The example does not demonstrate HOS_SOURCE_ONLY_TRIALWISE or +HOS_UNLABELED_100P. +""" + +from __future__ import annotations + +import warnings +from typing import Dict, Optional + +import numpy as np +import pandas as pd + +from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.linear_model import LogisticRegression +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import StandardScaler + +from pyriemann.estimation import Covariances +from pyriemann.tangentspace import TangentSpace +from pyriemann.utils.mean import mean_riemann + +from moabb.datasets import BNCI2014_001, Weibo2014 +from moabb.paradigms import LeftRightImagery + +from moabb.evaluations.cross_subject_target_aware_evaluation import ( + CrossSubjectTargetAwareEvaluation, + CsMode, +) + +# --------------------------------------------------------------------- +# SPD helpers +# --------------------------------------------------------------------- +def symmetrize(A: np.ndarray) -> np.ndarray: + return 0.5 * (A + A.T) + + +def nearest_spd_jitter(C: np.ndarray, eps: float = 1e-7) -> np.ndarray: + C = symmetrize(np.asarray(C, dtype=float)) + vals, vecs = np.linalg.eigh(C) + vals = np.maximum(vals, eps) + return symmetrize((vecs * vals) @ vecs.T) + + +def safe_mean_riemann(X: np.ndarray) -> np.ndarray: + X = np.asarray(X) + + if X.ndim != 3 or len(X) == 0: + raise ValueError( + "safe_mean_riemann expects shape " + "(n_matrices, n_channels, n_channels)." + ) + + X = np.stack([nearest_spd_jitter(C) for C in X], axis=0) + return nearest_spd_jitter(mean_riemann(X)) + + +def powerm_spd(C: np.ndarray, power: float, eps: float = 1e-12) -> np.ndarray: + C = nearest_spd_jitter(C) + vals, vecs = np.linalg.eigh(C) + vals = np.maximum(vals, eps) + return symmetrize((vecs * (vals ** power)) @ vecs.T) + + +def batch_congruence(X: np.ndarray, A: np.ndarray) -> np.ndarray: + X = np.asarray(X) + + return np.stack( + [nearest_spd_jitter(A @ C @ A.T) for C in X], + axis=0, + ) + + +# --------------------------------------------------------------------- +# Riemannian Procrustes Alignment transformer +# --------------------------------------------------------------------- +class RiemannianProcrustesAlignment(BaseEstimator, TransformerMixin): + """ + Riemannian Procrustes Alignment for covariance matrices. + + Parameters + ---------- + reference_subject : str or None, default="auto" + Reference domain used for alignment. + + - "auto": choose the source subject whose mean covariance is closest + to the global source mean. + - "global" or None: use the global source mean directly. + - any other value: use that source subject as the reference subject. + The value is converted to str and must exist among the source + subjects seen during fit. + + alignment_strength : float, default=1.0 + Strength of the alignment transform. + + - 0.0: no alignment. + - 1.0: full Procrustes-style alignment. + - values between 0.0 and 1.0: partial alignment. + + The transform is computed as: + A = M_ref ** (0.5 * alignment_strength) + @ M_domain ** (-0.5 * alignment_strength) + + use_global_transform_for_unseen : bool, default=True + If True, unseen data without an explicit target transform is aligned + using the transform from the global source mean to the reference mean. + If False, unseen data is returned after SPD cleanup only. + + verbose : bool, default=False + If True, print diagnostic information during fit, including the + selected reference subject, number of source subjects, and whether + target data was used. + """ + + def __init__( + self, + reference_subject: Optional[str] = "auto", + alignment_strength: float = 1.0, + use_global_transform_for_unseen: bool = True, + verbose: bool = False, + ): + self.reference_subject = reference_subject + self.alignment_strength = alignment_strength + self.use_global_transform_for_unseen = use_global_transform_for_unseen + self.verbose = verbose + + def _make_transform(self, M_ref: np.ndarray, M_source: np.ndarray) -> np.ndarray: + alpha = float(self.alignment_strength) + + if alpha == 0.0: + return np.eye(M_ref.shape[0]) + + return ( + powerm_spd(M_ref, 0.5 * alpha) + @ powerm_spd(M_source, -0.5 * alpha) + ) + + def fit( + self, + X, + y=None, + subjects=None, + cs_mode=None, + X_target_unlabeled=None, + X_target_labeled=None, + y_target_labeled=None, + **fit_params, + ): + """ + Input + ----- + X : ndarray, shape (n_trials, n_channels, n_channels) + + Fit metadata + ------------ + subjects : array-like + Source subject ID for each source training trial. + + cs_mode : str, optional + Cross-subject evaluation mode name, for example: + "HOS_SOURCE_ONLY_BLOCKWISE", "HOS_UNLABELED_20P", + "HOS_LABELED_20P". + + X_target_unlabeled : ndarray, optional + Unlabeled covariance matrices from the held-out target subject. + If provided, RPA uses them to estimate the target-domain alignment + transform. + + X_target_labeled : ndarray, optional + Labeled covariance matrices from the held-out target subject. + If provided and no unlabeled target data is provided, RPA uses + their covariance distribution for alignment. The labels themselves + are not used. + + y_target_labeled : ndarray, optional + Labels for X_target_labeled. RPA accepts this argument for + evaluator compatibility but does not use the labels directly. + """ + X = np.asarray(X) + + if X.ndim != 3: + raise ValueError( + "RiemannianProcrustesAlignment expects covariance matrices " + f"with shape (n_trials, n_channels, n_channels). Got {X.shape}." + ) + + if X.shape[1] != X.shape[2]: + raise ValueError("Covariance matrices must be square.") + + self.cs_mode_ = cs_mode + self.n_channels_ = int(X.shape[1]) + + if subjects is None: + warnings.warn( + "subjects was not provided to RPA.fit(). " + "Using one global source domain only.", + RuntimeWarning, + ) + subjects = np.array(["source"] * len(X), dtype=str) + else: + subjects = np.asarray(subjects).astype(str) + + if len(subjects) != len(X): + raise ValueError("X and subjects must have the same length.") + + self.source_subjects_ = np.unique(subjects).astype(str) + self.source_means_: Dict[str, np.ndarray] = {} + + for s in self.source_subjects_: + idx = subjects == s + self.source_means_[str(s)] = safe_mean_riemann(X[idx]) + + self.global_source_mean_ = safe_mean_riemann(X) + + # Choose reference domain. + if self.reference_subject is None or self.reference_subject == "global": + self.reference_subject_ = "global" + self.reference_mean_ = self.global_source_mean_ + + elif self.reference_subject == "auto": + distances = {} + + for s, M_s in self.source_means_.items(): + distances[s] = float(np.linalg.norm(M_s - self.global_source_mean_)) + + self.reference_subject_ = min(distances, key=distances.get) + self.reference_mean_ = self.source_means_[self.reference_subject_] + + else: + self.reference_subject_ = str(self.reference_subject) + + if self.reference_subject_ not in self.source_means_: + raise ValueError( + f"reference_subject={self.reference_subject_!r} is not " + f"in source subjects {list(self.source_means_.keys())}." + ) + + self.reference_mean_ = self.source_means_[self.reference_subject_] + + # Source-subject transforms. + self.source_transforms_: Dict[str, np.ndarray] = {} + + for s, M_s in self.source_means_.items(): + self.source_transforms_[s] = self._make_transform( + self.reference_mean_, + M_s, + ) + + self.global_transform_ = self._make_transform( + self.reference_mean_, + self.global_source_mean_, + ) + + # The evaluation mode decides which target data is provided. + # Prefer explicit unlabeled target data when available. Otherwise, + # use labeled target covariance data for alignment, but ignore labels. + X_target_for_alignment = None + target_source_kind = "none" + + if X_target_unlabeled is not None and len(X_target_unlabeled) > 0: + X_target_for_alignment = X_target_unlabeled + target_source_kind = "unlabeled" + + elif X_target_labeled is not None and len(X_target_labeled) > 0: + X_target_for_alignment = X_target_labeled + target_source_kind = "labeled" + + self.target_transform_ = None + self.has_target_data_ = False + self.target_source_kind_ = target_source_kind + + if X_target_for_alignment is not None: + X_target_for_alignment = np.asarray(X_target_for_alignment) + + if X_target_for_alignment.ndim != 3: + raise ValueError( + "Target data for RPA must have shape " + "(n_trials, n_channels, n_channels)." + ) + + if ( + X_target_for_alignment.shape[1] != self.n_channels_ + or X_target_for_alignment.shape[2] != self.n_channels_ + ): + raise ValueError( + "Target covariance matrices have incompatible shape. " + f"Expected ({self.n_channels_}, {self.n_channels_}), " + f"got {X_target_for_alignment.shape[1:]}." + ) + + target_mean = safe_mean_riemann(X_target_for_alignment) + + self.target_transform_ = self._make_transform( + self.reference_mean_, + target_mean, + ) + self.has_target_data_ = True + + # Stored for sklearn Pipeline fit_transform. + self._fit_subjects_ = subjects.copy() + self._n_fit_samples_ = int(len(X)) + + if self.verbose: + n_unlab = 0 if X_target_unlabeled is None else len(X_target_unlabeled) + n_lab = 0 if X_target_labeled is None else len(X_target_labeled) + + print( + "RPA.fit | " + f"cs_mode={self.cs_mode_}, " + f"reference_subject={self.reference_subject_}, " + f"n_source_subjects={len(self.source_subjects_)}, " + f"alignment_strength={self.alignment_strength}, " + f"has_target_data={self.has_target_data_}, " + f"target_source_kind={self.target_source_kind_}, " + f"n_target_unlabeled={n_unlab}, " + f"n_target_labeled={n_lab}", + flush=True, + ) + + return self + + def transform(self, X, subjects=None): + self._check_is_fitted() + + X = np.asarray(X) + + if X.ndim != 3: + raise ValueError( + "RiemannianProcrustesAlignment expects covariance matrices " + f"with shape (n_trials, n_channels, n_channels). Got {X.shape}." + ) + + if X.shape[1] != self.n_channels_ or X.shape[2] != self.n_channels_: + raise ValueError( + f"Expected matrices with shape ({self.n_channels_}, " + f"{self.n_channels_}), got {X.shape[1:]}." + ) + + # Explicit source-subject transform. + if subjects is not None: + subjects = np.asarray(subjects).astype(str) + + if len(subjects) != len(X): + raise ValueError("X and subjects must have the same length.") + + X_out = np.empty_like(X, dtype=float) + + for s in np.unique(subjects): + s = str(s) + idx = subjects == s + + if s not in self.source_transforms_: + raise ValueError( + f"Unknown source subject {s!r}. Available: " + f"{list(self.source_transforms_.keys())}" + ) + + X_out[idx] = batch_congruence(X[idx], self.source_transforms_[s]) + + return X_out + + # sklearn Pipeline training transform: no subjects are passed to + # transform(), but the length matches the fitted training data. + if hasattr(self, "_fit_subjects_") and len(X) == self._n_fit_samples_: + X_out = np.empty_like(X, dtype=float) + + for s in np.unique(self._fit_subjects_): + s = str(s) + idx = self._fit_subjects_ == s + X_out[idx] = batch_congruence(X[idx], self.source_transforms_[s]) + + return X_out + + # Test / unseen data. + if self.target_transform_ is not None: + return batch_congruence(X, self.target_transform_) + + if self.use_global_transform_for_unseen: + return batch_congruence(X, self.global_transform_) + + return np.stack([nearest_spd_jitter(C) for C in X], axis=0) + + def fit_transform( + self, + X, + y=None, + subjects=None, + cs_mode=None, + X_target_unlabeled=None, + X_target_labeled=None, + y_target_labeled=None, + **fit_params, + ): + return self.fit( + X, + y=y, + subjects=subjects, + cs_mode=cs_mode, + X_target_unlabeled=X_target_unlabeled, + X_target_labeled=X_target_labeled, + y_target_labeled=y_target_labeled, + **fit_params, + ).transform(X, subjects=subjects) + + def _check_is_fitted(self): + if not hasattr(self, "reference_mean_"): + raise RuntimeError("RiemannianProcrustesAlignment is not fitted.") +# --------------------------------------------------------------------- +# Demo pipelines +# --------------------------------------------------------------------- + +def make_pipelines(): + ts_lr = make_pipeline( + Covariances(estimator="oas"), + TangentSpace(metric="riemann"), + StandardScaler(), + LogisticRegression( + C=1.0, + class_weight="balanced", + max_iter=5000, + random_state=42, + ), + ) + + rpa_ts_lr = make_pipeline( + Covariances(estimator="oas"), + + RiemannianProcrustesAlignment( + reference_subject="auto", + alignment_strength=1.0, + use_global_transform_for_unseen=True, + verbose=True, + ), + + TangentSpace(metric="riemann"), + StandardScaler(), + LogisticRegression( + C=1.0, + class_weight="balanced", + max_iter=5000, + random_state=42, + ), + ) + + return { + "TS + LR": ts_lr, + "RPA + TS + LR": rpa_ts_lr, + } + + +# --------------------------------------------------------------------- +# Result summaries +# --------------------------------------------------------------------- +def normalize_results(results: pd.DataFrame) -> pd.DataFrame: + results = results.copy() + + if "dataset" in results.columns: + results["dataset"] = results["dataset"].apply( + lambda d: d.code if hasattr(d, "code") else str(d) + ) + + return results + + +def summarize_per_dataset_pipeline(results: pd.DataFrame) -> pd.DataFrame: + """ + One row per dataset and pipeline. + """ + results = normalize_results(results) + + summary = ( + results + .groupby(["dataset", "pipeline"], as_index=False) + .agg( + n_folds=("score", "count"), + mean_ROC_AUC=("score", "mean"), + std_ROC_AUC=("score", "std"), + ) + .sort_values(["dataset", "mean_ROC_AUC"], ascending=[True, False]) + .reset_index(drop=True) + ) + + return summary + + +def summarize_global_pipeline(summary: pd.DataFrame) -> pd.DataFrame: + """ + One row per pipeline. + + The global mean is computed over dataset means, not over all folds. + This avoids giving more weight to datasets with more subjects/sessions. + """ + global_summary = ( + summary + .groupby("pipeline", as_index=False) + .agg( + n_datasets=("dataset", "nunique"), + total_folds=("n_folds", "sum"), + mean_ROC_AUC_over_datasets=("mean_ROC_AUC", "mean"), + std_ROC_AUC_over_datasets=("mean_ROC_AUC", "std"), + ) + .sort_values("mean_ROC_AUC_over_datasets", ascending=False) + .reset_index(drop=True) + ) + + return global_summary + + +def make_dataset_pipeline_table(summary: pd.DataFrame) -> pd.DataFrame: + """ + Wide table: rows are datasets, columns are pipelines. + Useful for quick visual comparison. + """ + table = summary.pivot( + index="dataset", + columns="pipeline", + values="mean_ROC_AUC", + ).reset_index() + + table.columns.name = None + return table + +# --------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------- + + +def main(): + dataset = Weibo2014()#BNCI2014_001() + + # Uncomment for a fast smoke test. + # dataset.subject_list = [1, 2, 3, 4, 5] + + datasets = [dataset] + + paradigm = LeftRightImagery( + fmin=8.0, + fmax=32.0, + resample=128, + ) + + pipelines = make_pipelines() + + # Good RPA examples: + # + # CsMode.HOS_UNLABELED_20P: + # First 20% of the held-out target subject is used without labels + # for target alignment. Remaining 80% is evaluated. + # + # CsMode.HOS_LABELED_20P: + # First 20% of the held-out target subject is passed separately as + # labeled calibration data. RPA may use its covariance distribution + # for alignment; a final classifier may use the labels if it supports + # X_target_labeled/y_target_labeled. + # + # CsMode.HOS_SOURCE_ONLY_BLOCKWISE: + # No target adaptation. This is closest to standard MOABB + # cross-subject block prediction. + # + cs_mode = CsMode.HOS_UNLABELED_20P # RPA aligns using the first 20% target data without labels. + #cs_mode = CsMode.HOS_UNLABELED_50P # RPA aligns using the first 20% target data without labels. + #cs_mode = CsMode.HOS_LABELED_20P # RPA aligns using the first 20% target data, ignoring labels. + #cs_mode = CsMode.HOS_SOURCE_ONLY_BLOCKWISE # RPA uses source-subject alignment only; no target adaptation. + + evaluation = CrossSubjectTargetAwareEvaluation( + paradigm=paradigm, + datasets=datasets, + cs_mode=cs_mode, + n_jobs=4, # 1 while debugging verbose RPA output + overwrite=True, + random_state=42, + ) + + results = evaluation.process(pipelines=pipelines) + results = normalize_results(results) + + useful_cols = [ + "dataset", + "subject", + "session", + "pipeline", + "score", + ] + + useful_cols = [c for c in useful_cols if c in results.columns] + + print("\nRaw results:") + print(results[useful_cols].to_string(index=False)) + + summary = summarize_per_dataset_pipeline(results) + + print("\nPer-dataset / per-pipeline summary:") + print(summary.to_string(index=False)) + + global_summary = summarize_global_pipeline(summary) + + print("\nGlobal per-pipeline summary:") + print(global_summary.to_string(index=False)) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/moabb/evaluations/cross_subject_target_aware_evaluation.py b/moabb/evaluations/cross_subject_target_aware_evaluation.py new file mode 100644 index 000000000..fb606d456 --- /dev/null +++ b/moabb/evaluations/cross_subject_target_aware_evaluation.py @@ -0,0 +1,895 @@ +""" +A cross-subject implementation specifically targeting transfer learning. + +It provides a leave-one-subject-out (LOSO) evaluation with controlled access +to the held-out subject. The held-out subject is the target subject being +evaluated. The code allows users to compare results in the domain of transfer +learning BCI. + +Currently, it provides 6 modes: + +HOS_SOURCE_ONLY_TRIALWISE - Source-only training; each held-out target + trial is predicted independently. This mode is + the closest to real-world BCI when we do not + have enough data. + +HOS_SOURCE_ONLY_BLOCKWISE - Source-only training; held-out target trials + are predicted as a block, matching standard + MOABB behavior. It is a compatibility mode + with CrossSubjectEvaluation. Do not use it; + use one of the other modes, as they are more + clearly defined. + +HOS_UNLABELED_20P - First 20% of held-out target trials are used unlabeled + for adaptation; the remaining 80% are evaluated. + +HOS_UNLABELED_50P - First 50% of held-out target trials are used unlabeled + for adaptation; the remaining 50% are evaluated. + +HOS_UNLABELED_100P - All held-out target trials are used unlabeled for + transductive adaptation and are also evaluated. + +HOS_LABELED_20P - First 20% of held-out target trials are used with labels + for supervised calibration; the remaining 80% are + evaluated. + +Important notes: + +1) HOS_SOURCE_ONLY_BLOCKWISE vs HOS_UNLABELED_100P + + HOS_SOURCE_ONLY_BLOCKWISE predicts the target-subject samples as a + block after source-only training. The held-out subject is not provided + during training. + + By contrast, HOS_UNLABELED_100P explicitly provides the entire + unlabeled target block during fit/adaptation before predicting it. + Therefore, HOS_SOURCE_ONLY_BLOCKWISE should be used with care, as it + can be misused if a pipeline delays training and uses the held-out + subject as training data, as in HOS_UNLABELED_100P. + +2) Labeled target data is not provided to old pipelines + + In HOS_LABELED_20P, target labeled data is only passed if the estimator + accepts X_target_labeled / y_target_labeled. This means that old/regular + pipelines continue not to receive any data from the held-out subject. + +3) Unlabeled target data also only works for special estimators + + HOS_UNLABEnbLED_20P/50P/100P modes work only if a pipeline step explicitly + accepts X_target_unlabeled. Regular pipelines will silently ignore the + target adaptation data, as they are unaware of it. + +4) Split order defines adaptation data + The “first 20%” depends on trial order. + +""" + +from __future__ import annotations + +import inspect +import time +import warnings +from contextlib import contextmanager +from enum import Enum, auto +from typing import Any, Optional + +import joblib +import numpy as np + +from sklearn.base import BaseEstimator, ClassifierMixin, clone +from sklearn.pipeline import Pipeline +from sklearn.preprocessing import LabelEncoder + +from tqdm import tqdm + +from moabb.evaluations import CrossSubjectEvaluation, CrossSubjectSplitter +from moabb.evaluations.utils import _create_scorer, _ensure_fitted + + +@contextmanager +def tqdm_joblib(tqdm_object): + """ + Context manager to patch joblib to report into tqdm on batch completion. + """ + + class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack): + def __call__(self, *args, **kwargs): + tqdm_object.update(n=self.batch_size) + return super().__call__(*args, **kwargs) + + old_callback = joblib.parallel.BatchCompletionCallBack + joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback + + try: + yield tqdm_object + finally: + joblib.parallel.BatchCompletionCallBack = old_callback + tqdm_object.close() + + +# --------------------------------------------------------------------- +# Target-aware cross-subject modes +# --------------------------------------------------------------------- + + +class CsMode(Enum): + HOS_SOURCE_ONLY_TRIALWISE = auto() # Source-only training; each held-out target trial is predicted independently. + HOS_SOURCE_ONLY_BLOCKWISE = auto() # Source-only training; held-out target trials are predicted as a block, matching standard MOABB behavior. + HOS_UNLABELED_20P = auto() # First 20% of held-out target trials are used unlabeled for adaptation; remaining 80% are evaluated. + HOS_UNLABELED_50P = auto() # First 50% of held-out target trials are used unlabeled for adaptation; remaining 50% are evaluated. + HOS_UNLABELED_100P = auto() # All held-out target trials are used unlabeled for transductive adaptation and also evaluated. + HOS_LABELED_20P = auto() # First 20% of held-out target trials are used with labels for supervised calibration; remaining 80% are evaluated. + + +def cs_mode_uses_unlabeled_target(mode: CsMode) -> bool: + return mode in { + CsMode.HOS_UNLABELED_20P, + CsMode.HOS_UNLABELED_50P, + CsMode.HOS_UNLABELED_100P, + } + + +def cs_mode_uses_labeled_target(mode: CsMode) -> bool: + return mode == CsMode.HOS_LABELED_20P + + +def cs_mode_target_fraction(mode: CsMode) -> float: + if mode in { + CsMode.HOS_SOURCE_ONLY_TRIALWISE, + CsMode.HOS_SOURCE_ONLY_BLOCKWISE, + }: + return 0.0 + + if mode == CsMode.HOS_UNLABELED_20P: + return 0.20 + + if mode == CsMode.HOS_UNLABELED_50P: + return 0.50 + + if mode == CsMode.HOS_UNLABELED_100P: + return 1.00 + + if mode == CsMode.HOS_LABELED_20P: + return 0.20 + + raise ValueError(f"Unknown CsMode: {mode!r}") + + +def split_target_for_cs_mode( + test_idx: np.ndarray, + cs_mode: CsMode, +) -> tuple[np.ndarray, np.ndarray]: + test_idx = np.asarray(test_idx, dtype=int) + + if len(test_idx) == 0: + raise ValueError("Empty held-out test index.") + + fraction = cs_mode_target_fraction(cs_mode) + + if fraction == 0.0: + return np.array([], dtype=int), test_idx + + if fraction == 1.0: + return test_idx, test_idx + + n_adapt = int(round(fraction * len(test_idx))) + n_adapt = max(1, n_adapt) + n_adapt = min(n_adapt, len(test_idx) - 1) + + target_adapt_idx = test_idx[:n_adapt] + eval_idx = test_idx[n_adapt:] + + return target_adapt_idx, eval_idx + + +# --------------------------------------------------------------------- +# Metadata-aware fitting utilities +# --------------------------------------------------------------------- + + +def estimator_accepts_argument( + estimator: Any, + method_name: str, + arg_name: str, +) -> bool: + """ + Return True if an estimator method explicitly accepts a metadata argument. + + This is the only compatibility layer kept here: metadata-aware estimators + may accept subjects, cs_mode, X_target_unlabeled, X_target_labeled, or + y_target_labeled, while normal sklearn estimators usually do not. + """ + method = getattr(estimator, method_name, None) + + if method is None: + return False + + try: + signature = inspect.signature(method) + except (TypeError, ValueError): + return False + + return arg_name in signature.parameters + + +def _safe_transform(step: Any, X: Any): + if X is None: + return None + + if not hasattr(step, "transform"): + return X + + return step.transform(X) + + +def _fit_transform_step_with_metadata( + step: Any, + X: Any, + y: np.ndarray, + subjects: Optional[np.ndarray] = None, + cs_mode: Optional[str] = None, + X_target_unlabeled: Optional[Any] = None, + X_target_labeled: Optional[Any] = None, + y_target_labeled: Optional[np.ndarray] = None, +): + fit_kwargs = {} + + if subjects is not None and estimator_accepts_argument(step, "fit", "subjects"): + fit_kwargs["subjects"] = subjects + + if cs_mode is not None and estimator_accepts_argument(step, "fit", "cs_mode"): + fit_kwargs["cs_mode"] = cs_mode + + if ( + X_target_unlabeled is not None + and estimator_accepts_argument(step, "fit", "X_target_unlabeled") + ): + fit_kwargs["X_target_unlabeled"] = X_target_unlabeled + + if ( + X_target_labeled is not None + and estimator_accepts_argument(step, "fit", "X_target_labeled") + ): + fit_kwargs["X_target_labeled"] = X_target_labeled + + if ( + y_target_labeled is not None + and estimator_accepts_argument(step, "fit", "y_target_labeled") + ): + fit_kwargs["y_target_labeled"] = y_target_labeled + + # Use fit_transform only for ordinary sklearn-style steps without metadata. + # When metadata is passed, call fit() and transform() explicitly because + # fit_transform() may be implemented separately and may not accept the same + # metadata arguments as fit(). + if hasattr(step, "fit_transform") and not fit_kwargs: + Xt = step.fit_transform(X, y) + else: + step.fit(X, y, **fit_kwargs) + + if not hasattr(step, "transform"): + raise TypeError( + f"Intermediate pipeline step {step!r} has no transform method." + ) + + Xt = step.transform(X) + + Xt_target_unlabeled = _safe_transform(step, X_target_unlabeled) + Xt_target_labeled = _safe_transform(step, X_target_labeled) + + return Xt, Xt_target_unlabeled, Xt_target_labeled + + +def fit_pipeline_with_subject_metadata( + estimator: Any, + X_train: Any, + y_train: np.ndarray, + subjects_train: Optional[np.ndarray] = None, + cs_mode: Optional[str] = None, + X_target_unlabeled: Optional[Any] = None, + X_target_labeled: Optional[Any] = None, + y_target_labeled: Optional[np.ndarray] = None, +) -> Any: + """ + Fit an estimator or sklearn Pipeline while passing metadata only to steps + that explicitly support it. + + Normal sklearn estimators remain supported because unsupported metadata + arguments are not passed to them. + """ + if not isinstance(estimator, Pipeline): + fit_kwargs = {} + + if subjects_train is not None and estimator_accepts_argument( + estimator, + "fit", + "subjects", + ): + fit_kwargs["subjects"] = subjects_train + + if cs_mode is not None and estimator_accepts_argument( + estimator, + "fit", + "cs_mode", + ): + fit_kwargs["cs_mode"] = cs_mode + + if ( + X_target_unlabeled is not None + and estimator_accepts_argument(estimator, "fit", "X_target_unlabeled") + ): + fit_kwargs["X_target_unlabeled"] = X_target_unlabeled + + if ( + X_target_labeled is not None + and estimator_accepts_argument(estimator, "fit", "X_target_labeled") + ): + fit_kwargs["X_target_labeled"] = X_target_labeled + + if ( + y_target_labeled is not None + and estimator_accepts_argument(estimator, "fit", "y_target_labeled") + ): + fit_kwargs["y_target_labeled"] = y_target_labeled + + estimator.fit(X_train, y_train, **fit_kwargs) + return estimator + + Xt_train = X_train + Xt_target_unlabeled = X_target_unlabeled + Xt_target_labeled = X_target_labeled + + for _step_name, step in estimator.steps[:-1]: + ( + Xt_train, + Xt_target_unlabeled, + Xt_target_labeled, + ) = _fit_transform_step_with_metadata( + step=step, + X=Xt_train, + y=y_train, + subjects=subjects_train, + cs_mode=cs_mode, + X_target_unlabeled=Xt_target_unlabeled, + X_target_labeled=Xt_target_labeled, + y_target_labeled=y_target_labeled, + ) + + _final_name, final_step = estimator.steps[-1] + + final_fit_kwargs = {} + + if subjects_train is not None and estimator_accepts_argument( + final_step, + "fit", + "subjects", + ): + final_fit_kwargs["subjects"] = subjects_train + + if cs_mode is not None and estimator_accepts_argument( + final_step, + "fit", + "cs_mode", + ): + final_fit_kwargs["cs_mode"] = cs_mode + + if ( + Xt_target_unlabeled is not None + and estimator_accepts_argument(final_step, "fit", "X_target_unlabeled") + ): + final_fit_kwargs["X_target_unlabeled"] = Xt_target_unlabeled + + if ( + Xt_target_labeled is not None + and estimator_accepts_argument(final_step, "fit", "X_target_labeled") + ): + final_fit_kwargs["X_target_labeled"] = Xt_target_labeled + + if ( + y_target_labeled is not None + and estimator_accepts_argument(final_step, "fit", "y_target_labeled") + ): + final_fit_kwargs["y_target_labeled"] = y_target_labeled + + final_step.fit(Xt_train, y_train, **final_fit_kwargs) + + return estimator + +class TrialwisePredictWrapper(ClassifierMixin, BaseEstimator): + """ + Wrap an already-fitted estimator and force one-trial-at-a-time prediction. + + This keeps MOABB/sklearn scorer compatibility while preventing the wrapped + estimator from receiving the full target test block during prediction. + + Has multi-class support. + """ + + _estimator_type = "classifier" + + def __init__(self, fitted_estimator): + self.fitted_estimator = fitted_estimator + _ensure_fitted(fitted_estimator) + self.classes_ = self._get_classes(fitted_estimator) + + def fit(self, X, y=None): + raise RuntimeError( + "TrialwisePredictWrapper wraps an already fitted estimator." + ) + + def predict(self, X): + return np.asarray( + [ + self.fitted_estimator.predict(self._slice_one(X, i))[0] + for i in range(len(X)) + ] + ) + + def predict_proba(self, X): + if not hasattr(self.fitted_estimator, "predict_proba"): + raise AttributeError( + "Wrapped estimator does not provide predict_proba." + ) + + rows = [ + self._first_row( + self.fitted_estimator.predict_proba(self._slice_one(X, i)) + ) + for i in range(len(X)) + ] + + return np.vstack(rows) + + def decision_function(self, X): + if not hasattr(self.fitted_estimator, "decision_function"): + raise AttributeError( + "Wrapped estimator does not provide decision_function." + ) + + rows = [ + self._first_row( + self.fitted_estimator.decision_function(self._slice_one(X, i)) + ) + for i in range(len(X)) + ] + + out = np.asarray(rows) + + # sklearn convention: + # binary decision_function -> shape (n_samples,) + # multiclass decision_function -> shape (n_samples, n_classes) + if out.ndim == 2 and out.shape[1] == 1: + return out.ravel() + + return out + + @staticmethod + def _slice_one(X, i): + return X[i : i + 1] + + @staticmethod + def _first_row(output): + arr = np.asarray(output) + + if arr.ndim == 0: + return arr.item() + + if arr.shape[0] == 1: + arr = arr[0] + + return arr + + @staticmethod + def _get_classes(estimator): + if hasattr(estimator, "classes_"): + return estimator.classes_ + + if hasattr(estimator, "steps"): + final_estimator = estimator.steps[-1][1] + if hasattr(final_estimator, "classes_"): + return final_estimator.classes_ + + raise AttributeError( + "Wrapped estimator does not expose classes_." + ) + +# --------------------------------------------------------------------- +# Main evaluation class +# --------------------------------------------------------------------- + +class CrossSubjectTargetAwareEvaluation(CrossSubjectEvaluation): + _eval_type = "CrossSubjectTargetAware" + _score_per_session = True + _needs_all_subjects = True + + def __init__( + self, + *args, + cs_mode: CsMode = CsMode.HOS_SOURCE_ONLY_BLOCKWISE, + **kwargs, + ): + super().__init__(*args, **kwargs) + + if not isinstance(cs_mode, CsMode): + raise ValueError( + "cs_mode must be an instance of CsMode. " + f"Got {cs_mode!r}." + ) + + self.cs_mode = cs_mode + + def _build_task_list( + self, + dataset, + y, + metadata, + splitter, + work_plan, + param_grid, + ): + """ + Build lightweight flattened MOABB 1.6 tasks with target-aware split + metadata. + + The task dictionary intentionally does not store X, y, metadata, groups, + or sessions. These are passed separately to the worker to avoid duplicating + large objects in every joblib task. + """ + groups = metadata["subject"].values + + tasks = [] + + for cv_ind, (train_idx, test_idx) in enumerate(splitter.split(y, metadata)): + train_idx = np.asarray(train_idx, dtype=int) + test_idx = np.asarray(test_idx, dtype=int) + + subject = groups[test_idx[0]] + + if subject in work_plan: + subject_key = subject + elif str(subject) in work_plan: + subject_key = str(subject) + else: + continue + + target_adapt_idx, eval_idx = split_target_for_cs_mode( + test_idx, + self.cs_mode, + ) + + if len(eval_idx) == 0: + warnings.warn( + f"{dataset.code} | subject={subject}: empty evaluation set " + f"for cs_mode={self.cs_mode.name}. Skipping fold.", + RuntimeWarning, + ) + continue + + for pipeline_name, pipeline in work_plan[subject_key].items(): + tasks.append( + { + "dataset": dataset, + "train_idx": train_idx, + "test_idx": test_idx, + "target_adapt_idx": target_adapt_idx, + "eval_idx": eval_idx, + "subject": subject, + "pipeline_name": pipeline_name, + "pipeline": pipeline, + "param_grid": param_grid, + "cv_ind": cv_ind, + } + ) + + return tasks + + def _create_splitter(self): + """ + Create the MOABB 1.6 cross-subject splitter. + + This delegates subject-level CV handling to MOABB's CrossSubjectSplitter, + so cv_class, random_state, and cv_kwargs keep the same meaning as in + CrossSubjectEvaluation. + """ + cv_kwargs = getattr(self, "cv_kwargs", {}) or {} + cv_class = getattr(self, "cv_class", None) + + if cv_class is None: + return CrossSubjectSplitter( + random_state=self.random_state, + **cv_kwargs, + ) + + return CrossSubjectSplitter( + cv_class=cv_class, + random_state=self.random_state, + **cv_kwargs, + ) + + def _evaluate_task(self, task, X, y, groups, sessions): + """ + Evaluate one flattened target-aware task. + + Large shared objects are passed as worker arguments instead of being stored + inside each task dictionary. + """ + dataset = task["dataset"] + + train_idx = task["train_idx"] + test_idx = task["test_idx"] + target_adapt_idx = task["target_adapt_idx"] + eval_idx = task["eval_idx"] + + subject = task["subject"] + name = task["pipeline_name"] + clf = task["pipeline"] + param_grid = task["param_grid"] + cv_ind = task["cv_ind"] + + if param_grid is not None: + raise NotImplementedError( + "param_grid/grid search is not supported by " + "CrossSubjectTargetAwareEvaluation yet. Inner GridSearchCV does " + "not pass subjects, X_target_unlabeled, X_target_labeled, or " + "y_target_labeled to inner fits. Please set param_grid=None." + ) + + nchan = self._get_nchan(X) + + cvclf = clone(clf) + + X_train = X[train_idx] + + if self.mne_labels: + y_train = y[train_idx] + y_eval_all = y + else: + fold_label_idx = np.unique( + np.concatenate([train_idx, test_idx]) + ) + + le = LabelEncoder() + le.fit(y[fold_label_idx]) + + y_train = le.transform(y[train_idx]) + + y_eval_all = np.empty_like(y, dtype=int) + y_eval_all[fold_label_idx] = le.transform(y[fold_label_idx]) + + subjects_train = groups[train_idx] + + X_target_unlabeled = None + X_target_labeled = None + y_target_labeled = None + + n_target_unlabeled = 0 + n_target_labeled = 0 + + if cs_mode_uses_unlabeled_target(self.cs_mode): + X_target_unlabeled = X[target_adapt_idx] + n_target_unlabeled = int(len(target_adapt_idx)) + + elif cs_mode_uses_labeled_target(self.cs_mode): + X_target_labeled = X[target_adapt_idx] + y_target_labeled = y_eval_all[target_adapt_idx] + n_target_labeled = int(len(target_adapt_idx)) + + duration = self._fit_estimator_with_target_metadata( + estimator=cvclf, + X_train=X_train, + y_train=y_train, + subjects_train=subjects_train, + cs_mode=self.cs_mode.name, + X_target_unlabeled=X_target_unlabeled, + X_target_labeled=X_target_labeled, + y_target_labeled=y_target_labeled, + ) + + self._maybe_save_model_cv( + cvclf, + dataset, + subject, + "", + name, + cv_ind, + eval_type=self._eval_type, + ) + + if self.cs_mode == CsMode.HOS_SOURCE_ONLY_TRIALWISE: + scoring_estimator = TrialwisePredictWrapper(cvclf) + else: + scoring_estimator = cvclf + + scorer = _create_scorer( + scoring_estimator, + self.paradigm.scoring, + ) + + results = [] + + for session in np.unique(sessions[eval_idx]): + session_eval_idx = eval_idx[sessions[eval_idx] == session] + + if len(session_eval_idx) == 0: + continue + + res = self._build_scored_result( + dataset=dataset, + subject=subject, + session=session, + pipeline=name, + n_samples=len(X_train), + n_channels=nchan, + duration=duration, + scorer=scorer, + model=scoring_estimator, + X_test=X[session_eval_idx], + y_test=y_eval_all[session_eval_idx], + ) + + res["cs_mode"] = self.cs_mode.name + res["n_source_train"] = int(len(train_idx)) + res["n_source_fit"] = int(len(X_train)) + res["n_train_total"] = int(len(X_train) + n_target_labeled) + res["n_heldout_total"] = int(len(test_idx)) + res["n_target_adapt"] = int(len(target_adapt_idx)) + res["n_target_eval"] = int(len(eval_idx)) + res["n_target_unlabeled"] = int(n_target_unlabeled) + res["n_target_labeled"] = int(n_target_labeled) + res["target_subject"] = subject + + if self.cs_mode == CsMode.HOS_SOURCE_ONLY_TRIALWISE: + res["predict_mode"] = "trialwise" + else: + res["predict_mode"] = "blockwise" + + results.append(res) + + return results + + def _evaluate_parallel_dataset( + self, + dataset, + pipelines, + param_grid, + process_pipeline, + postprocess_pipeline, + work_plan, + ): + """ + MOABB > 1.6-native flattened parallel evaluation. + + One task = one held-out subject fold x one pipeline. + + The task dictionaries are kept lightweight. Large shared objects are passed + separately to the worker. + """ + from joblib import Parallel, delayed + + if param_grid is not None: + raise NotImplementedError( + "param_grid/grid search is not supported by " + "CrossSubjectTargetAwareEvaluation yet. Inner GridSearchCV does " + "not pass subjects, X_target_unlabeled, X_target_labeled, or " + "y_target_labeled to inner fits. Please set param_grid=None." + ) + + subjects_to_load = ( + dataset.subject_list + if getattr(self, "_needs_all_subjects", False) + else list(work_plan.keys()) + ) + + run_pipes = { + name: pipe + for subject_pipelines in work_plan.values() + for name, pipe in subject_pipelines.items() + } + + X, y_raw, metadata = self._load_data( + dataset=dataset, + run_pipes=run_pipes, + process_pipeline=process_pipeline, + postprocess_pipeline=postprocess_pipeline, + subjects=subjects_to_load, + ) + + y = np.asarray(y_raw) + + groups = metadata["subject"].values + sessions = metadata["session"].values + + splitter = self._create_splitter() + + tasks = self._build_task_list( + dataset=dataset, + y=y, + metadata=metadata, + splitter=splitter, + work_plan=work_plan, + param_grid=param_grid, + ) + + if not tasks: + return [] + + desc = f"{dataset.code}-{self._eval_type}" + + if self.n_jobs == 1: + nested_results = [] + + for task in tqdm( + tasks, + total=len(tasks), + desc=desc, + unit="task", + dynamic_ncols=True, + ): + nested_results.append( + self._evaluate_task( + task=task, + X=X, + y=y, + groups=groups, + sessions=sessions, + ) + ) + + else: + with tqdm_joblib( + tqdm( + total=len(tasks), + desc=desc, + unit="task", + dynamic_ncols=True, + ) + ): + nested_results = Parallel(n_jobs=self.n_jobs, verbose=0)( + delayed(self._evaluate_task)( + task=task, + X=X, + y=y, + groups=groups, + sessions=sessions, + ) + for task in tasks + ) + + all_results = [] + + for rows in nested_results: + all_results.extend(rows) + + return all_results + + def _fit_estimator_with_target_metadata( + self, + estimator: Any, + X_train: Any, + y_train: np.ndarray, + subjects_train: Optional[np.ndarray] = None, + cs_mode: Optional[str] = None, + X_target_unlabeled: Optional[Any] = None, + X_target_labeled: Optional[Any] = None, + y_target_labeled: Optional[np.ndarray] = None, + ): + """ + Fit one estimator with optional target-aware metadata. + """ + start_time = time.time() + + fit_pipeline_with_subject_metadata( + estimator=estimator, + X_train=X_train, + y_train=y_train, + subjects_train=subjects_train, + cs_mode=cs_mode, + X_target_unlabeled=X_target_unlabeled, + X_target_labeled=X_target_labeled, + y_target_labeled=y_target_labeled, + ) + + _ensure_fitted(estimator) + + duration = time.time() - start_time + return duration From e28e29db4daeec412de61f20d04b54c6c34e476e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Jun 2026 09:12:03 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks --- ...cross_subject_transfer_learning_example.py | 102 ++--- .../cross_subject_target_aware_evaluation.py | 352 +++++++----------- 2 files changed, 168 insertions(+), 286 deletions(-) diff --git a/examples/how_to_benchmark/cross_subject_transfer_learning_example.py b/examples/how_to_benchmark/cross_subject_transfer_learning_example.py index b1643b1af..411864660 100644 --- a/examples/how_to_benchmark/cross_subject_transfer_learning_example.py +++ b/examples/how_to_benchmark/cross_subject_transfer_learning_example.py @@ -40,23 +40,21 @@ import numpy as np import pandas as pd - +from pyriemann.estimation import Covariances +from pyriemann.tangentspace import TangentSpace +from pyriemann.utils.mean import mean_riemann from sklearn.base import BaseEstimator, TransformerMixin from sklearn.linear_model import LogisticRegression from sklearn.pipeline import make_pipeline from sklearn.preprocessing import StandardScaler -from pyriemann.estimation import Covariances -from pyriemann.tangentspace import TangentSpace -from pyriemann.utils.mean import mean_riemann - -from moabb.datasets import BNCI2014_001, Weibo2014 -from moabb.paradigms import LeftRightImagery - +from moabb.datasets import Weibo2014 from moabb.evaluations.cross_subject_target_aware_evaluation import ( CrossSubjectTargetAwareEvaluation, CsMode, ) +from moabb.paradigms import LeftRightImagery + # --------------------------------------------------------------------- # SPD helpers @@ -77,8 +75,7 @@ def safe_mean_riemann(X: np.ndarray) -> np.ndarray: if X.ndim != 3 or len(X) == 0: raise ValueError( - "safe_mean_riemann expects shape " - "(n_matrices, n_channels, n_channels)." + "safe_mean_riemann expects shape (n_matrices, n_channels, n_channels)." ) X = np.stack([nearest_spd_jitter(C) for C in X], axis=0) @@ -89,16 +86,13 @@ def powerm_spd(C: np.ndarray, power: float, eps: float = 1e-12) -> np.ndarray: C = nearest_spd_jitter(C) vals, vecs = np.linalg.eigh(C) vals = np.maximum(vals, eps) - return symmetrize((vecs * (vals ** power)) @ vecs.T) + return symmetrize((vecs * (vals**power)) @ vecs.T) def batch_congruence(X: np.ndarray, A: np.ndarray) -> np.ndarray: X = np.asarray(X) - return np.stack( - [nearest_spd_jitter(A @ C @ A.T) for C in X], - axis=0, - ) + return np.stack([nearest_spd_jitter(A @ C @ A.T) for C in X], axis=0) # --------------------------------------------------------------------- @@ -160,10 +154,7 @@ def _make_transform(self, M_ref: np.ndarray, M_source: np.ndarray) -> np.ndarray if alpha == 0.0: return np.eye(M_ref.shape[0]) - return ( - powerm_spd(M_ref, 0.5 * alpha) - @ powerm_spd(M_source, -0.5 * alpha) - ) + return powerm_spd(M_ref, 0.5 * alpha) @ powerm_spd(M_source, -0.5 * alpha) def fit( self, @@ -271,14 +262,10 @@ def fit( self.source_transforms_: Dict[str, np.ndarray] = {} for s, M_s in self.source_means_.items(): - self.source_transforms_[s] = self._make_transform( - self.reference_mean_, - M_s, - ) + self.source_transforms_[s] = self._make_transform(self.reference_mean_, M_s) self.global_transform_ = self._make_transform( - self.reference_mean_, - self.global_source_mean_, + self.reference_mean_, self.global_source_mean_ ) # The evaluation mode decides which target data is provided. @@ -321,8 +308,7 @@ def fit( target_mean = safe_mean_riemann(X_target_for_alignment) self.target_transform_ = self._make_transform( - self.reference_mean_, - target_mean, + self.reference_mean_, target_mean ) self.has_target_data_ = True @@ -435,47 +421,39 @@ def fit_transform( def _check_is_fitted(self): if not hasattr(self, "reference_mean_"): raise RuntimeError("RiemannianProcrustesAlignment is not fitted.") + + # --------------------------------------------------------------------- # Demo pipelines # --------------------------------------------------------------------- + def make_pipelines(): ts_lr = make_pipeline( Covariances(estimator="oas"), TangentSpace(metric="riemann"), StandardScaler(), LogisticRegression( - C=1.0, - class_weight="balanced", - max_iter=5000, - random_state=42, + C=1.0, class_weight="balanced", max_iter=5000, random_state=42 ), ) rpa_ts_lr = make_pipeline( Covariances(estimator="oas"), - RiemannianProcrustesAlignment( reference_subject="auto", alignment_strength=1.0, use_global_transform_for_unseen=True, verbose=True, ), - TangentSpace(metric="riemann"), StandardScaler(), LogisticRegression( - C=1.0, - class_weight="balanced", - max_iter=5000, - random_state=42, + C=1.0, class_weight="balanced", max_iter=5000, random_state=42 ), ) - return { - "TS + LR": ts_lr, - "RPA + TS + LR": rpa_ts_lr, - } + return {"TS + LR": ts_lr, "RPA + TS + LR": rpa_ts_lr} # --------------------------------------------------------------------- @@ -499,8 +477,7 @@ def summarize_per_dataset_pipeline(results: pd.DataFrame) -> pd.DataFrame: results = normalize_results(results) summary = ( - results - .groupby(["dataset", "pipeline"], as_index=False) + results.groupby(["dataset", "pipeline"], as_index=False) .agg( n_folds=("score", "count"), mean_ROC_AUC=("score", "mean"), @@ -521,8 +498,7 @@ def summarize_global_pipeline(summary: pd.DataFrame) -> pd.DataFrame: This avoids giving more weight to datasets with more subjects/sessions. """ global_summary = ( - summary - .groupby("pipeline", as_index=False) + summary.groupby("pipeline", as_index=False) .agg( n_datasets=("dataset", "nunique"), total_folds=("n_folds", "sum"), @@ -542,32 +518,27 @@ def make_dataset_pipeline_table(summary: pd.DataFrame) -> pd.DataFrame: Useful for quick visual comparison. """ table = summary.pivot( - index="dataset", - columns="pipeline", - values="mean_ROC_AUC", + index="dataset", columns="pipeline", values="mean_ROC_AUC" ).reset_index() table.columns.name = None return table + # --------------------------------------------------------------------- # Main # --------------------------------------------------------------------- def main(): - dataset = Weibo2014()#BNCI2014_001() + dataset = Weibo2014() # BNCI2014_001() # Uncomment for a fast smoke test. # dataset.subject_list = [1, 2, 3, 4, 5] datasets = [dataset] - paradigm = LeftRightImagery( - fmin=8.0, - fmax=32.0, - resample=128, - ) + paradigm = LeftRightImagery(fmin=8.0, fmax=32.0, resample=128) pipelines = make_pipelines() @@ -587,16 +558,18 @@ def main(): # No target adaptation. This is closest to standard MOABB # cross-subject block prediction. # - cs_mode = CsMode.HOS_UNLABELED_20P # RPA aligns using the first 20% target data without labels. - #cs_mode = CsMode.HOS_UNLABELED_50P # RPA aligns using the first 20% target data without labels. - #cs_mode = CsMode.HOS_LABELED_20P # RPA aligns using the first 20% target data, ignoring labels. - #cs_mode = CsMode.HOS_SOURCE_ONLY_BLOCKWISE # RPA uses source-subject alignment only; no target adaptation. + cs_mode = ( + CsMode.HOS_UNLABELED_20P + ) # RPA aligns using the first 20% target data without labels. + # cs_mode = CsMode.HOS_UNLABELED_50P # RPA aligns using the first 20% target data without labels. + # cs_mode = CsMode.HOS_LABELED_20P # RPA aligns using the first 20% target data, ignoring labels. + # cs_mode = CsMode.HOS_SOURCE_ONLY_BLOCKWISE # RPA uses source-subject alignment only; no target adaptation. evaluation = CrossSubjectTargetAwareEvaluation( paradigm=paradigm, datasets=datasets, cs_mode=cs_mode, - n_jobs=4, # 1 while debugging verbose RPA output + n_jobs=4, # 1 while debugging verbose RPA output overwrite=True, random_state=42, ) @@ -604,13 +577,7 @@ def main(): results = evaluation.process(pipelines=pipelines) results = normalize_results(results) - useful_cols = [ - "dataset", - "subject", - "session", - "pipeline", - "score", - ] + useful_cols = ["dataset", "subject", "session", "pipeline", "score"] useful_cols = [c for c in useful_cols if c in results.columns] @@ -627,5 +594,6 @@ def main(): print("\nGlobal per-pipeline summary:") print(global_summary.to_string(index=False)) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/moabb/evaluations/cross_subject_target_aware_evaluation.py b/moabb/evaluations/cross_subject_target_aware_evaluation.py index fb606d456..e2858e276 100644 --- a/moabb/evaluations/cross_subject_target_aware_evaluation.py +++ b/moabb/evaluations/cross_subject_target_aware_evaluation.py @@ -34,7 +34,7 @@ evaluated. Important notes: - + 1) HOS_SOURCE_ONLY_BLOCKWISE vs HOS_UNLABELED_100P HOS_SOURCE_ONLY_BLOCKWISE predicts the target-subject samples as a @@ -58,7 +58,7 @@ HOS_UNLABEnbLED_20P/50P/100P modes work only if a pipeline step explicitly accepts X_target_unlabeled. Regular pipelines will silently ignore the target adaptation data, as they are unaware of it. - + 4) Split order defines adaptation data The “first 20%” depends on trial order. @@ -75,11 +75,9 @@ import joblib import numpy as np - from sklearn.base import BaseEstimator, ClassifierMixin, clone from sklearn.pipeline import Pipeline from sklearn.preprocessing import LabelEncoder - from tqdm import tqdm from moabb.evaluations import CrossSubjectEvaluation, CrossSubjectSplitter @@ -113,12 +111,14 @@ def __call__(self, *args, **kwargs): class CsMode(Enum): - HOS_SOURCE_ONLY_TRIALWISE = auto() # Source-only training; each held-out target trial is predicted independently. + HOS_SOURCE_ONLY_TRIALWISE = ( + auto() + ) # Source-only training; each held-out target trial is predicted independently. HOS_SOURCE_ONLY_BLOCKWISE = auto() # Source-only training; held-out target trials are predicted as a block, matching standard MOABB behavior. - HOS_UNLABELED_20P = auto() # First 20% of held-out target trials are used unlabeled for adaptation; remaining 80% are evaluated. - HOS_UNLABELED_50P = auto() # First 50% of held-out target trials are used unlabeled for adaptation; remaining 50% are evaluated. - HOS_UNLABELED_100P = auto() # All held-out target trials are used unlabeled for transductive adaptation and also evaluated. - HOS_LABELED_20P = auto() # First 20% of held-out target trials are used with labels for supervised calibration; remaining 80% are evaluated. + HOS_UNLABELED_20P = auto() # First 20% of held-out target trials are used unlabeled for adaptation; remaining 80% are evaluated. + HOS_UNLABELED_50P = auto() # First 50% of held-out target trials are used unlabeled for adaptation; remaining 50% are evaluated. + HOS_UNLABELED_100P = auto() # All held-out target trials are used unlabeled for transductive adaptation and also evaluated. + HOS_LABELED_20P = auto() # First 20% of held-out target trials are used with labels for supervised calibration; remaining 80% are evaluated. def cs_mode_uses_unlabeled_target(mode: CsMode) -> bool: @@ -134,10 +134,7 @@ def cs_mode_uses_labeled_target(mode: CsMode) -> bool: def cs_mode_target_fraction(mode: CsMode) -> float: - if mode in { - CsMode.HOS_SOURCE_ONLY_TRIALWISE, - CsMode.HOS_SOURCE_ONLY_BLOCKWISE, - }: + if mode in {CsMode.HOS_SOURCE_ONLY_TRIALWISE, CsMode.HOS_SOURCE_ONLY_BLOCKWISE}: return 0.0 if mode == CsMode.HOS_UNLABELED_20P: @@ -156,8 +153,7 @@ def cs_mode_target_fraction(mode: CsMode) -> float: def split_target_for_cs_mode( - test_idx: np.ndarray, - cs_mode: CsMode, + test_idx: np.ndarray, cs_mode: CsMode ) -> tuple[np.ndarray, np.ndarray]: test_idx = np.asarray(test_idx, dtype=int) @@ -187,11 +183,7 @@ def split_target_for_cs_mode( # --------------------------------------------------------------------- -def estimator_accepts_argument( - estimator: Any, - method_name: str, - arg_name: str, -) -> bool: +def estimator_accepts_argument(estimator: Any, method_name: str, arg_name: str) -> bool: """ Return True if an estimator method explicitly accepts a metadata argument. @@ -240,21 +232,18 @@ def _fit_transform_step_with_metadata( if cs_mode is not None and estimator_accepts_argument(step, "fit", "cs_mode"): fit_kwargs["cs_mode"] = cs_mode - if ( - X_target_unlabeled is not None - and estimator_accepts_argument(step, "fit", "X_target_unlabeled") + if X_target_unlabeled is not None and estimator_accepts_argument( + step, "fit", "X_target_unlabeled" ): fit_kwargs["X_target_unlabeled"] = X_target_unlabeled - if ( - X_target_labeled is not None - and estimator_accepts_argument(step, "fit", "X_target_labeled") + if X_target_labeled is not None and estimator_accepts_argument( + step, "fit", "X_target_labeled" ): fit_kwargs["X_target_labeled"] = X_target_labeled - if ( - y_target_labeled is not None - and estimator_accepts_argument(step, "fit", "y_target_labeled") + if y_target_labeled is not None and estimator_accepts_argument( + step, "fit", "y_target_labeled" ): fit_kwargs["y_target_labeled"] = y_target_labeled @@ -266,12 +255,12 @@ def _fit_transform_step_with_metadata( Xt = step.fit_transform(X, y) else: step.fit(X, y, **fit_kwargs) - + if not hasattr(step, "transform"): raise TypeError( f"Intermediate pipeline step {step!r} has no transform method." ) - + Xt = step.transform(X) Xt_target_unlabeled = _safe_transform(step, X_target_unlabeled) @@ -301,34 +290,27 @@ def fit_pipeline_with_subject_metadata( fit_kwargs = {} if subjects_train is not None and estimator_accepts_argument( - estimator, - "fit", - "subjects", + estimator, "fit", "subjects" ): fit_kwargs["subjects"] = subjects_train if cs_mode is not None and estimator_accepts_argument( - estimator, - "fit", - "cs_mode", + estimator, "fit", "cs_mode" ): fit_kwargs["cs_mode"] = cs_mode - if ( - X_target_unlabeled is not None - and estimator_accepts_argument(estimator, "fit", "X_target_unlabeled") + if X_target_unlabeled is not None and estimator_accepts_argument( + estimator, "fit", "X_target_unlabeled" ): fit_kwargs["X_target_unlabeled"] = X_target_unlabeled - if ( - X_target_labeled is not None - and estimator_accepts_argument(estimator, "fit", "X_target_labeled") + if X_target_labeled is not None and estimator_accepts_argument( + estimator, "fit", "X_target_labeled" ): fit_kwargs["X_target_labeled"] = X_target_labeled - if ( - y_target_labeled is not None - and estimator_accepts_argument(estimator, "fit", "y_target_labeled") + if y_target_labeled is not None and estimator_accepts_argument( + estimator, "fit", "y_target_labeled" ): fit_kwargs["y_target_labeled"] = y_target_labeled @@ -340,19 +322,17 @@ def fit_pipeline_with_subject_metadata( Xt_target_labeled = X_target_labeled for _step_name, step in estimator.steps[:-1]: - ( - Xt_train, - Xt_target_unlabeled, - Xt_target_labeled, - ) = _fit_transform_step_with_metadata( - step=step, - X=Xt_train, - y=y_train, - subjects=subjects_train, - cs_mode=cs_mode, - X_target_unlabeled=Xt_target_unlabeled, - X_target_labeled=Xt_target_labeled, - y_target_labeled=y_target_labeled, + (Xt_train, Xt_target_unlabeled, Xt_target_labeled) = ( + _fit_transform_step_with_metadata( + step=step, + X=Xt_train, + y=y_train, + subjects=subjects_train, + cs_mode=cs_mode, + X_target_unlabeled=Xt_target_unlabeled, + X_target_labeled=Xt_target_labeled, + y_target_labeled=y_target_labeled, + ) ) _final_name, final_step = estimator.steps[-1] @@ -360,34 +340,25 @@ def fit_pipeline_with_subject_metadata( final_fit_kwargs = {} if subjects_train is not None and estimator_accepts_argument( - final_step, - "fit", - "subjects", + final_step, "fit", "subjects" ): final_fit_kwargs["subjects"] = subjects_train - if cs_mode is not None and estimator_accepts_argument( - final_step, - "fit", - "cs_mode", - ): + if cs_mode is not None and estimator_accepts_argument(final_step, "fit", "cs_mode"): final_fit_kwargs["cs_mode"] = cs_mode - if ( - Xt_target_unlabeled is not None - and estimator_accepts_argument(final_step, "fit", "X_target_unlabeled") + if Xt_target_unlabeled is not None and estimator_accepts_argument( + final_step, "fit", "X_target_unlabeled" ): final_fit_kwargs["X_target_unlabeled"] = Xt_target_unlabeled - if ( - Xt_target_labeled is not None - and estimator_accepts_argument(final_step, "fit", "X_target_labeled") + if Xt_target_labeled is not None and estimator_accepts_argument( + final_step, "fit", "X_target_labeled" ): final_fit_kwargs["X_target_labeled"] = Xt_target_labeled - if ( - y_target_labeled is not None - and estimator_accepts_argument(final_step, "fit", "y_target_labeled") + if y_target_labeled is not None and estimator_accepts_argument( + final_step, "fit", "y_target_labeled" ): final_fit_kwargs["y_target_labeled"] = y_target_labeled @@ -395,13 +366,14 @@ def fit_pipeline_with_subject_metadata( return estimator + class TrialwisePredictWrapper(ClassifierMixin, BaseEstimator): """ Wrap an already-fitted estimator and force one-trial-at-a-time prediction. This keeps MOABB/sklearn scorer compatibility while preventing the wrapped estimator from receiving the full target test block during prediction. - + Has multi-class support. """ @@ -413,9 +385,7 @@ def __init__(self, fitted_estimator): self.classes_ = self._get_classes(fitted_estimator) def fit(self, X, y=None): - raise RuntimeError( - "TrialwisePredictWrapper wraps an already fitted estimator." - ) + raise RuntimeError("TrialwisePredictWrapper wraps an already fitted estimator.") def predict(self, X): return np.asarray( @@ -427,14 +397,10 @@ def predict(self, X): def predict_proba(self, X): if not hasattr(self.fitted_estimator, "predict_proba"): - raise AttributeError( - "Wrapped estimator does not provide predict_proba." - ) + raise AttributeError("Wrapped estimator does not provide predict_proba.") rows = [ - self._first_row( - self.fitted_estimator.predict_proba(self._slice_one(X, i)) - ) + self._first_row(self.fitted_estimator.predict_proba(self._slice_one(X, i))) for i in range(len(X)) ] @@ -442,9 +408,7 @@ def predict_proba(self, X): def decision_function(self, X): if not hasattr(self.fitted_estimator, "decision_function"): - raise AttributeError( - "Wrapped estimator does not provide decision_function." - ) + raise AttributeError("Wrapped estimator does not provide decision_function.") rows = [ self._first_row( @@ -489,74 +453,57 @@ def _get_classes(estimator): if hasattr(final_estimator, "classes_"): return final_estimator.classes_ - raise AttributeError( - "Wrapped estimator does not expose classes_." - ) + raise AttributeError("Wrapped estimator does not expose classes_.") + # --------------------------------------------------------------------- # Main evaluation class # --------------------------------------------------------------------- + class CrossSubjectTargetAwareEvaluation(CrossSubjectEvaluation): _eval_type = "CrossSubjectTargetAware" _score_per_session = True _needs_all_subjects = True def __init__( - self, - *args, - cs_mode: CsMode = CsMode.HOS_SOURCE_ONLY_BLOCKWISE, - **kwargs, + self, *args, cs_mode: CsMode = CsMode.HOS_SOURCE_ONLY_BLOCKWISE, **kwargs ): super().__init__(*args, **kwargs) if not isinstance(cs_mode, CsMode): - raise ValueError( - "cs_mode must be an instance of CsMode. " - f"Got {cs_mode!r}." - ) + raise ValueError(f"cs_mode must be an instance of CsMode. Got {cs_mode!r}.") self.cs_mode = cs_mode - def _build_task_list( - self, - dataset, - y, - metadata, - splitter, - work_plan, - param_grid, - ): + def _build_task_list(self, dataset, y, metadata, splitter, work_plan, param_grid): """ Build lightweight flattened MOABB 1.6 tasks with target-aware split metadata. - + The task dictionary intentionally does not store X, y, metadata, groups, or sessions. These are passed separately to the worker to avoid duplicating large objects in every joblib task. """ groups = metadata["subject"].values - + tasks = [] - + for cv_ind, (train_idx, test_idx) in enumerate(splitter.split(y, metadata)): train_idx = np.asarray(train_idx, dtype=int) test_idx = np.asarray(test_idx, dtype=int) - + subject = groups[test_idx[0]] - + if subject in work_plan: subject_key = subject elif str(subject) in work_plan: subject_key = str(subject) else: continue - - target_adapt_idx, eval_idx = split_target_for_cs_mode( - test_idx, - self.cs_mode, - ) - + + target_adapt_idx, eval_idx = split_target_for_cs_mode(test_idx, self.cs_mode) + if len(eval_idx) == 0: warnings.warn( f"{dataset.code} | subject={subject}: empty evaluation set " @@ -564,7 +511,7 @@ def _build_task_list( RuntimeWarning, ) continue - + for pipeline_name, pipeline in work_plan[subject_key].items(): tasks.append( { @@ -580,52 +527,47 @@ def _build_task_list( "cv_ind": cv_ind, } ) - + return tasks - + def _create_splitter(self): """ Create the MOABB 1.6 cross-subject splitter. - + This delegates subject-level CV handling to MOABB's CrossSubjectSplitter, so cv_class, random_state, and cv_kwargs keep the same meaning as in CrossSubjectEvaluation. """ cv_kwargs = getattr(self, "cv_kwargs", {}) or {} cv_class = getattr(self, "cv_class", None) - + if cv_class is None: - return CrossSubjectSplitter( - random_state=self.random_state, - **cv_kwargs, - ) - + return CrossSubjectSplitter(random_state=self.random_state, **cv_kwargs) + return CrossSubjectSplitter( - cv_class=cv_class, - random_state=self.random_state, - **cv_kwargs, + cv_class=cv_class, random_state=self.random_state, **cv_kwargs ) - + def _evaluate_task(self, task, X, y, groups, sessions): """ Evaluate one flattened target-aware task. - + Large shared objects are passed as worker arguments instead of being stored inside each task dictionary. """ dataset = task["dataset"] - + train_idx = task["train_idx"] test_idx = task["test_idx"] target_adapt_idx = task["target_adapt_idx"] eval_idx = task["eval_idx"] - + subject = task["subject"] name = task["pipeline_name"] clf = task["pipeline"] param_grid = task["param_grid"] cv_ind = task["cv_ind"] - + if param_grid is not None: raise NotImplementedError( "param_grid/grid search is not supported by " @@ -633,47 +575,45 @@ def _evaluate_task(self, task, X, y, groups, sessions): "not pass subjects, X_target_unlabeled, X_target_labeled, or " "y_target_labeled to inner fits. Please set param_grid=None." ) - + nchan = self._get_nchan(X) - + cvclf = clone(clf) - + X_train = X[train_idx] if self.mne_labels: y_train = y[train_idx] y_eval_all = y else: - fold_label_idx = np.unique( - np.concatenate([train_idx, test_idx]) - ) - + fold_label_idx = np.unique(np.concatenate([train_idx, test_idx])) + le = LabelEncoder() le.fit(y[fold_label_idx]) - + y_train = le.transform(y[train_idx]) - + y_eval_all = np.empty_like(y, dtype=int) y_eval_all[fold_label_idx] = le.transform(y[fold_label_idx]) - + subjects_train = groups[train_idx] - + X_target_unlabeled = None X_target_labeled = None y_target_labeled = None - + n_target_unlabeled = 0 n_target_labeled = 0 - + if cs_mode_uses_unlabeled_target(self.cs_mode): X_target_unlabeled = X[target_adapt_idx] n_target_unlabeled = int(len(target_adapt_idx)) - + elif cs_mode_uses_labeled_target(self.cs_mode): X_target_labeled = X[target_adapt_idx] y_target_labeled = y_eval_all[target_adapt_idx] n_target_labeled = int(len(target_adapt_idx)) - + duration = self._fit_estimator_with_target_metadata( estimator=cvclf, X_train=X_train, @@ -684,35 +624,26 @@ def _evaluate_task(self, task, X, y, groups, sessions): X_target_labeled=X_target_labeled, y_target_labeled=y_target_labeled, ) - + self._maybe_save_model_cv( - cvclf, - dataset, - subject, - "", - name, - cv_ind, - eval_type=self._eval_type, + cvclf, dataset, subject, "", name, cv_ind, eval_type=self._eval_type ) - + if self.cs_mode == CsMode.HOS_SOURCE_ONLY_TRIALWISE: scoring_estimator = TrialwisePredictWrapper(cvclf) else: scoring_estimator = cvclf - - scorer = _create_scorer( - scoring_estimator, - self.paradigm.scoring, - ) - + + scorer = _create_scorer(scoring_estimator, self.paradigm.scoring) + results = [] - + for session in np.unique(sessions[eval_idx]): session_eval_idx = eval_idx[sessions[eval_idx] == session] - + if len(session_eval_idx) == 0: continue - + res = self._build_scored_result( dataset=dataset, subject=subject, @@ -726,7 +657,7 @@ def _evaluate_task(self, task, X, y, groups, sessions): X_test=X[session_eval_idx], y_test=y_eval_all[session_eval_idx], ) - + res["cs_mode"] = self.cs_mode.name res["n_source_train"] = int(len(train_idx)) res["n_source_fit"] = int(len(X_train)) @@ -737,16 +668,16 @@ def _evaluate_task(self, task, X, y, groups, sessions): res["n_target_unlabeled"] = int(n_target_unlabeled) res["n_target_labeled"] = int(n_target_labeled) res["target_subject"] = subject - + if self.cs_mode == CsMode.HOS_SOURCE_ONLY_TRIALWISE: res["predict_mode"] = "trialwise" else: res["predict_mode"] = "blockwise" - + results.append(res) - + return results - + def _evaluate_parallel_dataset( self, dataset, @@ -758,14 +689,14 @@ def _evaluate_parallel_dataset( ): """ MOABB > 1.6-native flattened parallel evaluation. - + One task = one held-out subject fold x one pipeline. - + The task dictionaries are kept lightweight. Large shared objects are passed separately to the worker. """ from joblib import Parallel, delayed - + if param_grid is not None: raise NotImplementedError( "param_grid/grid search is not supported by " @@ -773,19 +704,19 @@ def _evaluate_parallel_dataset( "not pass subjects, X_target_unlabeled, X_target_labeled, or " "y_target_labeled to inner fits. Please set param_grid=None." ) - + subjects_to_load = ( dataset.subject_list if getattr(self, "_needs_all_subjects", False) else list(work_plan.keys()) ) - + run_pipes = { name: pipe for subject_pipelines in work_plan.values() for name, pipe in subject_pipelines.items() } - + X, y_raw, metadata = self._load_data( dataset=dataset, run_pipes=run_pipes, @@ -793,14 +724,14 @@ def _evaluate_parallel_dataset( postprocess_pipeline=postprocess_pipeline, subjects=subjects_to_load, ) - + y = np.asarray(y_raw) - + groups = metadata["subject"].values sessions = metadata["session"].values - + splitter = self._create_splitter() - + tasks = self._build_task_list( dataset=dataset, y=y, @@ -809,59 +740,42 @@ def _evaluate_parallel_dataset( work_plan=work_plan, param_grid=param_grid, ) - + if not tasks: return [] - + desc = f"{dataset.code}-{self._eval_type}" - + if self.n_jobs == 1: nested_results = [] - + for task in tqdm( - tasks, - total=len(tasks), - desc=desc, - unit="task", - dynamic_ncols=True, + tasks, total=len(tasks), desc=desc, unit="task", dynamic_ncols=True ): nested_results.append( self._evaluate_task( - task=task, - X=X, - y=y, - groups=groups, - sessions=sessions, + task=task, X=X, y=y, groups=groups, sessions=sessions ) ) - + else: with tqdm_joblib( - tqdm( - total=len(tasks), - desc=desc, - unit="task", - dynamic_ncols=True, - ) + tqdm(total=len(tasks), desc=desc, unit="task", dynamic_ncols=True) ): nested_results = Parallel(n_jobs=self.n_jobs, verbose=0)( delayed(self._evaluate_task)( - task=task, - X=X, - y=y, - groups=groups, - sessions=sessions, + task=task, X=X, y=y, groups=groups, sessions=sessions ) for task in tasks ) - + all_results = [] - + for rows in nested_results: all_results.extend(rows) - + return all_results - + def _fit_estimator_with_target_metadata( self, estimator: Any, @@ -877,7 +791,7 @@ def _fit_estimator_with_target_metadata( Fit one estimator with optional target-aware metadata. """ start_time = time.time() - + fit_pipeline_with_subject_metadata( estimator=estimator, X_train=X_train, @@ -888,8 +802,8 @@ def _fit_estimator_with_target_metadata( X_target_labeled=X_target_labeled, y_target_labeled=y_target_labeled, ) - + _ensure_fitted(estimator) - + duration = time.time() - start_time return duration