diff --git a/examples/how_to_benchmark/cross_subject_transfer_learning_example.py b/examples/how_to_benchmark/cross_subject_transfer_learning_example.py new file mode 100644 index 000000000..411864660 --- /dev/null +++ b/examples/how_to_benchmark/cross_subject_transfer_learning_example.py @@ -0,0 +1,599 @@ +""" +Example of the transfer-learning-oriented +CrossSubjectTargetAwareEvaluation. + +This example compares a standard TS + LR pipeline with an RPA + TS + LR +pipeline. Riemannian Procrustes Alignment is used here as a simple example of +a target-aware transfer-learning method: it can use source-subject structure and, +when allowed by the evaluation mode, unlabeled target covariance data for +alignment. + +The script can demonstrate 4 of the 6 available modes by changing `cs_mode`: + +* HOS_SOURCE_ONLY_BLOCKWISE: + No target data is used during adaptation. RPA aligns only the source + subjects, so this should usually be weaker than target-adaptive modes. + +* HOS_UNLABELED_20P: + The first 20% of the held-out target subject is provided without labels + for target-domain alignment. The remaining 80% is evaluated. + +* HOS_UNLABELED_50P: + The first 50% of the held-out target subject is provided without labels + for target-domain alignment. The remaining 50% is evaluated. + +* HOS_LABELED_20P: + The first 20% of the held-out target subject is provided as labeled + calibration data. In this example, RPA ignores the labels and uses only + the covariance distribution for alignment, so HOS_LABELED_20P and + HOS_UNLABELED_20P should give very similar or identical results for the + RPA step. + +The example does not demonstrate HOS_SOURCE_ONLY_TRIALWISE or +HOS_UNLABELED_100P. 
+""" + +from __future__ import annotations + +import warnings +from typing import Dict, Optional + +import numpy as np +import pandas as pd +from pyriemann.estimation import Covariances +from pyriemann.tangentspace import TangentSpace +from pyriemann.utils.mean import mean_riemann +from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.linear_model import LogisticRegression +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import StandardScaler + +from moabb.datasets import Weibo2014 +from moabb.evaluations.cross_subject_target_aware_evaluation import ( + CrossSubjectTargetAwareEvaluation, + CsMode, +) +from moabb.paradigms import LeftRightImagery + + +# --------------------------------------------------------------------- +# SPD helpers +# --------------------------------------------------------------------- +def symmetrize(A: np.ndarray) -> np.ndarray: + return 0.5 * (A + A.T) + + +def nearest_spd_jitter(C: np.ndarray, eps: float = 1e-7) -> np.ndarray: + C = symmetrize(np.asarray(C, dtype=float)) + vals, vecs = np.linalg.eigh(C) + vals = np.maximum(vals, eps) + return symmetrize((vecs * vals) @ vecs.T) + + +def safe_mean_riemann(X: np.ndarray) -> np.ndarray: + X = np.asarray(X) + + if X.ndim != 3 or len(X) == 0: + raise ValueError( + "safe_mean_riemann expects shape (n_matrices, n_channels, n_channels)." 
+ ) + + X = np.stack([nearest_spd_jitter(C) for C in X], axis=0) + return nearest_spd_jitter(mean_riemann(X)) + + +def powerm_spd(C: np.ndarray, power: float, eps: float = 1e-12) -> np.ndarray: + C = nearest_spd_jitter(C) + vals, vecs = np.linalg.eigh(C) + vals = np.maximum(vals, eps) + return symmetrize((vecs * (vals**power)) @ vecs.T) + + +def batch_congruence(X: np.ndarray, A: np.ndarray) -> np.ndarray: + X = np.asarray(X) + + return np.stack([nearest_spd_jitter(A @ C @ A.T) for C in X], axis=0) + + +# --------------------------------------------------------------------- +# Riemannian Procrustes Alignment transformer +# --------------------------------------------------------------------- +class RiemannianProcrustesAlignment(BaseEstimator, TransformerMixin): + """ + Riemannian Procrustes Alignment for covariance matrices. + + Parameters + ---------- + reference_subject : str or None, default="auto" + Reference domain used for alignment. + + - "auto": choose the source subject whose mean covariance is closest + to the global source mean. + - "global" or None: use the global source mean directly. + - any other value: use that source subject as the reference subject. + The value is converted to str and must exist among the source + subjects seen during fit. + + alignment_strength : float, default=1.0 + Strength of the alignment transform. + + - 0.0: no alignment. + - 1.0: full Procrustes-style alignment. + - values between 0.0 and 1.0: partial alignment. + + The transform is computed as: + A = M_ref ** (0.5 * alignment_strength) + @ M_domain ** (-0.5 * alignment_strength) + + use_global_transform_for_unseen : bool, default=True + If True, unseen data without an explicit target transform is aligned + using the transform from the global source mean to the reference mean. + If False, unseen data is returned after SPD cleanup only. 
+ + verbose : bool, default=False + If True, print diagnostic information during fit, including the + selected reference subject, number of source subjects, and whether + target data was used. + """ + + def __init__( + self, + reference_subject: Optional[str] = "auto", + alignment_strength: float = 1.0, + use_global_transform_for_unseen: bool = True, + verbose: bool = False, + ): + self.reference_subject = reference_subject + self.alignment_strength = alignment_strength + self.use_global_transform_for_unseen = use_global_transform_for_unseen + self.verbose = verbose + + def _make_transform(self, M_ref: np.ndarray, M_source: np.ndarray) -> np.ndarray: + alpha = float(self.alignment_strength) + + if alpha == 0.0: + return np.eye(M_ref.shape[0]) + + return powerm_spd(M_ref, 0.5 * alpha) @ powerm_spd(M_source, -0.5 * alpha) + + def fit( + self, + X, + y=None, + subjects=None, + cs_mode=None, + X_target_unlabeled=None, + X_target_labeled=None, + y_target_labeled=None, + **fit_params, + ): + """ + Input + ----- + X : ndarray, shape (n_trials, n_channels, n_channels) + + Fit metadata + ------------ + subjects : array-like + Source subject ID for each source training trial. + + cs_mode : str, optional + Cross-subject evaluation mode name, for example: + "HOS_SOURCE_ONLY_BLOCKWISE", "HOS_UNLABELED_20P", + "HOS_LABELED_20P". + + X_target_unlabeled : ndarray, optional + Unlabeled covariance matrices from the held-out target subject. + If provided, RPA uses them to estimate the target-domain alignment + transform. + + X_target_labeled : ndarray, optional + Labeled covariance matrices from the held-out target subject. + If provided and no unlabeled target data is provided, RPA uses + their covariance distribution for alignment. The labels themselves + are not used. + + y_target_labeled : ndarray, optional + Labels for X_target_labeled. RPA accepts this argument for + evaluator compatibility but does not use the labels directly. 
+ """ + X = np.asarray(X) + + if X.ndim != 3: + raise ValueError( + "RiemannianProcrustesAlignment expects covariance matrices " + f"with shape (n_trials, n_channels, n_channels). Got {X.shape}." + ) + + if X.shape[1] != X.shape[2]: + raise ValueError("Covariance matrices must be square.") + + self.cs_mode_ = cs_mode + self.n_channels_ = int(X.shape[1]) + + if subjects is None: + warnings.warn( + "subjects was not provided to RPA.fit(). " + "Using one global source domain only.", + RuntimeWarning, + ) + subjects = np.array(["source"] * len(X), dtype=str) + else: + subjects = np.asarray(subjects).astype(str) + + if len(subjects) != len(X): + raise ValueError("X and subjects must have the same length.") + + self.source_subjects_ = np.unique(subjects).astype(str) + self.source_means_: Dict[str, np.ndarray] = {} + + for s in self.source_subjects_: + idx = subjects == s + self.source_means_[str(s)] = safe_mean_riemann(X[idx]) + + self.global_source_mean_ = safe_mean_riemann(X) + + # Choose reference domain. + if self.reference_subject is None or self.reference_subject == "global": + self.reference_subject_ = "global" + self.reference_mean_ = self.global_source_mean_ + + elif self.reference_subject == "auto": + distances = {} + + for s, M_s in self.source_means_.items(): + distances[s] = float(np.linalg.norm(M_s - self.global_source_mean_)) + + self.reference_subject_ = min(distances, key=distances.get) + self.reference_mean_ = self.source_means_[self.reference_subject_] + + else: + self.reference_subject_ = str(self.reference_subject) + + if self.reference_subject_ not in self.source_means_: + raise ValueError( + f"reference_subject={self.reference_subject_!r} is not " + f"in source subjects {list(self.source_means_.keys())}." + ) + + self.reference_mean_ = self.source_means_[self.reference_subject_] + + # Source-subject transforms. 
+ self.source_transforms_: Dict[str, np.ndarray] = {} + + for s, M_s in self.source_means_.items(): + self.source_transforms_[s] = self._make_transform(self.reference_mean_, M_s) + + self.global_transform_ = self._make_transform( + self.reference_mean_, self.global_source_mean_ + ) + + # The evaluation mode decides which target data is provided. + # Prefer explicit unlabeled target data when available. Otherwise, + # use labeled target covariance data for alignment, but ignore labels. + X_target_for_alignment = None + target_source_kind = "none" + + if X_target_unlabeled is not None and len(X_target_unlabeled) > 0: + X_target_for_alignment = X_target_unlabeled + target_source_kind = "unlabeled" + + elif X_target_labeled is not None and len(X_target_labeled) > 0: + X_target_for_alignment = X_target_labeled + target_source_kind = "labeled" + + self.target_transform_ = None + self.has_target_data_ = False + self.target_source_kind_ = target_source_kind + + if X_target_for_alignment is not None: + X_target_for_alignment = np.asarray(X_target_for_alignment) + + if X_target_for_alignment.ndim != 3: + raise ValueError( + "Target data for RPA must have shape " + "(n_trials, n_channels, n_channels)." + ) + + if ( + X_target_for_alignment.shape[1] != self.n_channels_ + or X_target_for_alignment.shape[2] != self.n_channels_ + ): + raise ValueError( + "Target covariance matrices have incompatible shape. " + f"Expected ({self.n_channels_}, {self.n_channels_}), " + f"got {X_target_for_alignment.shape[1:]}." + ) + + target_mean = safe_mean_riemann(X_target_for_alignment) + + self.target_transform_ = self._make_transform( + self.reference_mean_, target_mean + ) + self.has_target_data_ = True + + # Stored for sklearn Pipeline fit_transform. 
+ self._fit_subjects_ = subjects.copy() + self._n_fit_samples_ = int(len(X)) + + if self.verbose: + n_unlab = 0 if X_target_unlabeled is None else len(X_target_unlabeled) + n_lab = 0 if X_target_labeled is None else len(X_target_labeled) + + print( + "RPA.fit | " + f"cs_mode={self.cs_mode_}, " + f"reference_subject={self.reference_subject_}, " + f"n_source_subjects={len(self.source_subjects_)}, " + f"alignment_strength={self.alignment_strength}, " + f"has_target_data={self.has_target_data_}, " + f"target_source_kind={self.target_source_kind_}, " + f"n_target_unlabeled={n_unlab}, " + f"n_target_labeled={n_lab}", + flush=True, + ) + + return self + + def transform(self, X, subjects=None): + self._check_is_fitted() + + X = np.asarray(X) + + if X.ndim != 3: + raise ValueError( + "RiemannianProcrustesAlignment expects covariance matrices " + f"with shape (n_trials, n_channels, n_channels). Got {X.shape}." + ) + + if X.shape[1] != self.n_channels_ or X.shape[2] != self.n_channels_: + raise ValueError( + f"Expected matrices with shape ({self.n_channels_}, " + f"{self.n_channels_}), got {X.shape[1:]}." + ) + + # Explicit source-subject transform. + if subjects is not None: + subjects = np.asarray(subjects).astype(str) + + if len(subjects) != len(X): + raise ValueError("X and subjects must have the same length.") + + X_out = np.empty_like(X, dtype=float) + + for s in np.unique(subjects): + s = str(s) + idx = subjects == s + + if s not in self.source_transforms_: + raise ValueError( + f"Unknown source subject {s!r}. Available: " + f"{list(self.source_transforms_.keys())}" + ) + + X_out[idx] = batch_congruence(X[idx], self.source_transforms_[s]) + + return X_out + + # sklearn Pipeline training transform: no subjects are passed to + # transform(), but the length matches the fitted training data. 
+ if hasattr(self, "_fit_subjects_") and len(X) == self._n_fit_samples_: + X_out = np.empty_like(X, dtype=float) + + for s in np.unique(self._fit_subjects_): + s = str(s) + idx = self._fit_subjects_ == s + X_out[idx] = batch_congruence(X[idx], self.source_transforms_[s]) + + return X_out + + # Test / unseen data. + if self.target_transform_ is not None: + return batch_congruence(X, self.target_transform_) + + if self.use_global_transform_for_unseen: + return batch_congruence(X, self.global_transform_) + + return np.stack([nearest_spd_jitter(C) for C in X], axis=0) + + def fit_transform( + self, + X, + y=None, + subjects=None, + cs_mode=None, + X_target_unlabeled=None, + X_target_labeled=None, + y_target_labeled=None, + **fit_params, + ): + return self.fit( + X, + y=y, + subjects=subjects, + cs_mode=cs_mode, + X_target_unlabeled=X_target_unlabeled, + X_target_labeled=X_target_labeled, + y_target_labeled=y_target_labeled, + **fit_params, + ).transform(X, subjects=subjects) + + def _check_is_fitted(self): + if not hasattr(self, "reference_mean_"): + raise RuntimeError("RiemannianProcrustesAlignment is not fitted.") + + +# --------------------------------------------------------------------- +# Demo pipelines +# --------------------------------------------------------------------- + + +def make_pipelines(): + ts_lr = make_pipeline( + Covariances(estimator="oas"), + TangentSpace(metric="riemann"), + StandardScaler(), + LogisticRegression( + C=1.0, class_weight="balanced", max_iter=5000, random_state=42 + ), + ) + + rpa_ts_lr = make_pipeline( + Covariances(estimator="oas"), + RiemannianProcrustesAlignment( + reference_subject="auto", + alignment_strength=1.0, + use_global_transform_for_unseen=True, + verbose=True, + ), + TangentSpace(metric="riemann"), + StandardScaler(), + LogisticRegression( + C=1.0, class_weight="balanced", max_iter=5000, random_state=42 + ), + ) + + return {"TS + LR": ts_lr, "RPA + TS + LR": rpa_ts_lr} + + +# 
--------------------------------------------------------------------- +# Result summaries +# --------------------------------------------------------------------- +def normalize_results(results: pd.DataFrame) -> pd.DataFrame: + results = results.copy() + + if "dataset" in results.columns: + results["dataset"] = results["dataset"].apply( + lambda d: d.code if hasattr(d, "code") else str(d) + ) + + return results + + +def summarize_per_dataset_pipeline(results: pd.DataFrame) -> pd.DataFrame: + """ + One row per dataset and pipeline. + """ + results = normalize_results(results) + + summary = ( + results.groupby(["dataset", "pipeline"], as_index=False) + .agg( + n_folds=("score", "count"), + mean_ROC_AUC=("score", "mean"), + std_ROC_AUC=("score", "std"), + ) + .sort_values(["dataset", "mean_ROC_AUC"], ascending=[True, False]) + .reset_index(drop=True) + ) + + return summary + + +def summarize_global_pipeline(summary: pd.DataFrame) -> pd.DataFrame: + """ + One row per pipeline. + + The global mean is computed over dataset means, not over all folds. + This avoids giving more weight to datasets with more subjects/sessions. + """ + global_summary = ( + summary.groupby("pipeline", as_index=False) + .agg( + n_datasets=("dataset", "nunique"), + total_folds=("n_folds", "sum"), + mean_ROC_AUC_over_datasets=("mean_ROC_AUC", "mean"), + std_ROC_AUC_over_datasets=("mean_ROC_AUC", "std"), + ) + .sort_values("mean_ROC_AUC_over_datasets", ascending=False) + .reset_index(drop=True) + ) + + return global_summary + + +def make_dataset_pipeline_table(summary: pd.DataFrame) -> pd.DataFrame: + """ + Wide table: rows are datasets, columns are pipelines. + Useful for quick visual comparison. 
+ """ + table = summary.pivot( + index="dataset", columns="pipeline", values="mean_ROC_AUC" + ).reset_index() + + table.columns.name = None + return table + + +# --------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------- + + +def main(): + dataset = Weibo2014() # BNCI2014_001() + + # Uncomment for a fast smoke test. + # dataset.subject_list = [1, 2, 3, 4, 5] + + datasets = [dataset] + + paradigm = LeftRightImagery(fmin=8.0, fmax=32.0, resample=128) + + pipelines = make_pipelines() + + # Good RPA examples: + # + # CsMode.HOS_UNLABELED_20P: + # First 20% of the held-out target subject is used without labels + # for target alignment. Remaining 80% is evaluated. + # + # CsMode.HOS_LABELED_20P: + # First 20% of the held-out target subject is passed separately as + # labeled calibration data. RPA may use its covariance distribution + # for alignment; a final classifier may use the labels if it supports + # X_target_labeled/y_target_labeled. + # + # CsMode.HOS_SOURCE_ONLY_BLOCKWISE: + # No target adaptation. This is closest to standard MOABB + # cross-subject block prediction. + # + cs_mode = ( + CsMode.HOS_UNLABELED_20P + ) # RPA aligns using the first 20% target data without labels. + # cs_mode = CsMode.HOS_UNLABELED_50P # RPA aligns using the first 50% target data without labels. + # cs_mode = CsMode.HOS_LABELED_20P # RPA aligns using the first 20% target data, ignoring labels. + # cs_mode = CsMode.HOS_SOURCE_ONLY_BLOCKWISE # RPA uses source-subject alignment only; no target adaptation. 
+ + evaluation = CrossSubjectTargetAwareEvaluation( + paradigm=paradigm, + datasets=datasets, + cs_mode=cs_mode, + n_jobs=4, # 1 while debugging verbose RPA output + overwrite=True, + random_state=42, + ) + + results = evaluation.process(pipelines=pipelines) + results = normalize_results(results) + + useful_cols = ["dataset", "subject", "session", "pipeline", "score"] + + useful_cols = [c for c in useful_cols if c in results.columns] + + print("\nRaw results:") + print(results[useful_cols].to_string(index=False)) + + summary = summarize_per_dataset_pipeline(results) + + print("\nPer-dataset / per-pipeline summary:") + print(summary.to_string(index=False)) + + global_summary = summarize_global_pipeline(summary) + + print("\nGlobal per-pipeline summary:") + print(global_summary.to_string(index=False)) + + +if __name__ == "__main__": + main() diff --git a/moabb/evaluations/cross_subject_target_aware_evaluation.py b/moabb/evaluations/cross_subject_target_aware_evaluation.py new file mode 100644 index 000000000..e2858e276 --- /dev/null +++ b/moabb/evaluations/cross_subject_target_aware_evaluation.py @@ -0,0 +1,809 @@ +""" +A cross-subject implementation specifically targeting transfer learning. + +It provides a leave-one-subject-out (LOSO) evaluation with controlled access +to the held-out subject. The held-out subject is the target subject being +evaluated. The code allows users to compare results in the domain of transfer +learning BCI. + +Currently, it provides 6 modes: + +HOS_SOURCE_ONLY_TRIALWISE - Source-only training; each held-out target + trial is predicted independently. This mode is + the closest to real-world BCI when we do not + have enough data. + +HOS_SOURCE_ONLY_BLOCKWISE - Source-only training; held-out target trials + are predicted as a block, matching standard + MOABB behavior. It is a compatibility mode + with CrossSubjectEvaluation. Do not use it; + use one of the other modes, as they are more + clearly defined. 
+ +HOS_UNLABELED_20P - First 20% of held-out target trials are used unlabeled + for adaptation; the remaining 80% are evaluated. + +HOS_UNLABELED_50P - First 50% of held-out target trials are used unlabeled + for adaptation; the remaining 50% are evaluated. + +HOS_UNLABELED_100P - All held-out target trials are used unlabeled for + transductive adaptation and are also evaluated. + +HOS_LABELED_20P - First 20% of held-out target trials are used with labels + for supervised calibration; the remaining 80% are + evaluated. + +Important notes: + +1) HOS_SOURCE_ONLY_BLOCKWISE vs HOS_UNLABELED_100P + + HOS_SOURCE_ONLY_BLOCKWISE predicts the target-subject samples as a + block after source-only training. The held-out subject is not provided + during training. + + By contrast, HOS_UNLABELED_100P explicitly provides the entire + unlabeled target block during fit/adaptation before predicting it. + Therefore, HOS_SOURCE_ONLY_BLOCKWISE should be used with care, as it + can be misused if a pipeline delays training and uses the held-out + subject as training data, as in HOS_UNLABELED_100P. + +2) Labeled target data is not provided to old pipelines + + In HOS_LABELED_20P, target labeled data is only passed if the estimator + accepts X_target_labeled / y_target_labeled. This means that old/regular + pipelines continue not to receive any data from the held-out subject. + +3) Unlabeled target data also only works for special estimators + + HOS_UNLABELED_20P/50P/100P modes work only if a pipeline step explicitly + accepts X_target_unlabeled. Regular pipelines will silently ignore the + target adaptation data, as they are unaware of it. + +4) Split order defines adaptation data + The “first 20%” depends on trial order. 
+ +""" + +from __future__ import annotations + +import inspect +import time +import warnings +from contextlib import contextmanager +from enum import Enum, auto +from typing import Any, Optional + +import joblib +import numpy as np +from sklearn.base import BaseEstimator, ClassifierMixin, clone +from sklearn.pipeline import Pipeline +from sklearn.preprocessing import LabelEncoder +from tqdm import tqdm + +from moabb.evaluations import CrossSubjectEvaluation, CrossSubjectSplitter +from moabb.evaluations.utils import _create_scorer, _ensure_fitted + + +@contextmanager +def tqdm_joblib(tqdm_object): + """ + Context manager to patch joblib to report into tqdm on batch completion. + """ + + class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack): + def __call__(self, *args, **kwargs): + tqdm_object.update(n=self.batch_size) + return super().__call__(*args, **kwargs) + + old_callback = joblib.parallel.BatchCompletionCallBack + joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback + + try: + yield tqdm_object + finally: + joblib.parallel.BatchCompletionCallBack = old_callback + tqdm_object.close() + + +# --------------------------------------------------------------------- +# Target-aware cross-subject modes +# --------------------------------------------------------------------- + + +class CsMode(Enum): + HOS_SOURCE_ONLY_TRIALWISE = ( + auto() + ) # Source-only training; each held-out target trial is predicted independently. + HOS_SOURCE_ONLY_BLOCKWISE = auto() # Source-only training; held-out target trials are predicted as a block, matching standard MOABB behavior. + HOS_UNLABELED_20P = auto() # First 20% of held-out target trials are used unlabeled for adaptation; remaining 80% are evaluated. + HOS_UNLABELED_50P = auto() # First 50% of held-out target trials are used unlabeled for adaptation; remaining 50% are evaluated. 
+ HOS_UNLABELED_100P = auto() # All held-out target trials are used unlabeled for transductive adaptation and also evaluated. + HOS_LABELED_20P = auto() # First 20% of held-out target trials are used with labels for supervised calibration; remaining 80% are evaluated. + + +def cs_mode_uses_unlabeled_target(mode: CsMode) -> bool: + return mode in { + CsMode.HOS_UNLABELED_20P, + CsMode.HOS_UNLABELED_50P, + CsMode.HOS_UNLABELED_100P, + } + + +def cs_mode_uses_labeled_target(mode: CsMode) -> bool: + return mode == CsMode.HOS_LABELED_20P + + +def cs_mode_target_fraction(mode: CsMode) -> float: + if mode in {CsMode.HOS_SOURCE_ONLY_TRIALWISE, CsMode.HOS_SOURCE_ONLY_BLOCKWISE}: + return 0.0 + + if mode == CsMode.HOS_UNLABELED_20P: + return 0.20 + + if mode == CsMode.HOS_UNLABELED_50P: + return 0.50 + + if mode == CsMode.HOS_UNLABELED_100P: + return 1.00 + + if mode == CsMode.HOS_LABELED_20P: + return 0.20 + + raise ValueError(f"Unknown CsMode: {mode!r}") + + +def split_target_for_cs_mode( + test_idx: np.ndarray, cs_mode: CsMode +) -> tuple[np.ndarray, np.ndarray]: + test_idx = np.asarray(test_idx, dtype=int) + + if len(test_idx) == 0: + raise ValueError("Empty held-out test index.") + + fraction = cs_mode_target_fraction(cs_mode) + + if fraction == 0.0: + return np.array([], dtype=int), test_idx + + if fraction == 1.0: + return test_idx, test_idx + + n_adapt = int(round(fraction * len(test_idx))) + n_adapt = max(1, n_adapt) + n_adapt = min(n_adapt, len(test_idx) - 1) + + target_adapt_idx = test_idx[:n_adapt] + eval_idx = test_idx[n_adapt:] + + return target_adapt_idx, eval_idx + + +# --------------------------------------------------------------------- +# Metadata-aware fitting utilities +# --------------------------------------------------------------------- + + +def estimator_accepts_argument(estimator: Any, method_name: str, arg_name: str) -> bool: + """ + Return True if an estimator method explicitly accepts a metadata argument. 
+ + This is the only compatibility layer kept here: metadata-aware estimators + may accept subjects, cs_mode, X_target_unlabeled, X_target_labeled, or + y_target_labeled, while normal sklearn estimators usually do not. + """ + method = getattr(estimator, method_name, None) + + if method is None: + return False + + try: + signature = inspect.signature(method) + except (TypeError, ValueError): + return False + + return arg_name in signature.parameters + + +def _safe_transform(step: Any, X: Any): + if X is None: + return None + + if not hasattr(step, "transform"): + return X + + return step.transform(X) + + +def _fit_transform_step_with_metadata( + step: Any, + X: Any, + y: np.ndarray, + subjects: Optional[np.ndarray] = None, + cs_mode: Optional[str] = None, + X_target_unlabeled: Optional[Any] = None, + X_target_labeled: Optional[Any] = None, + y_target_labeled: Optional[np.ndarray] = None, +): + fit_kwargs = {} + + if subjects is not None and estimator_accepts_argument(step, "fit", "subjects"): + fit_kwargs["subjects"] = subjects + + if cs_mode is not None and estimator_accepts_argument(step, "fit", "cs_mode"): + fit_kwargs["cs_mode"] = cs_mode + + if X_target_unlabeled is not None and estimator_accepts_argument( + step, "fit", "X_target_unlabeled" + ): + fit_kwargs["X_target_unlabeled"] = X_target_unlabeled + + if X_target_labeled is not None and estimator_accepts_argument( + step, "fit", "X_target_labeled" + ): + fit_kwargs["X_target_labeled"] = X_target_labeled + + if y_target_labeled is not None and estimator_accepts_argument( + step, "fit", "y_target_labeled" + ): + fit_kwargs["y_target_labeled"] = y_target_labeled + + # Use fit_transform only for ordinary sklearn-style steps without metadata. + # When metadata is passed, call fit() and transform() explicitly because + # fit_transform() may be implemented separately and may not accept the same + # metadata arguments as fit(). 
+ if hasattr(step, "fit_transform") and not fit_kwargs: + Xt = step.fit_transform(X, y) + else: + step.fit(X, y, **fit_kwargs) + + if not hasattr(step, "transform"): + raise TypeError( + f"Intermediate pipeline step {step!r} has no transform method." + ) + + Xt = step.transform(X) + + Xt_target_unlabeled = _safe_transform(step, X_target_unlabeled) + Xt_target_labeled = _safe_transform(step, X_target_labeled) + + return Xt, Xt_target_unlabeled, Xt_target_labeled + + +def fit_pipeline_with_subject_metadata( + estimator: Any, + X_train: Any, + y_train: np.ndarray, + subjects_train: Optional[np.ndarray] = None, + cs_mode: Optional[str] = None, + X_target_unlabeled: Optional[Any] = None, + X_target_labeled: Optional[Any] = None, + y_target_labeled: Optional[np.ndarray] = None, +) -> Any: + """ + Fit an estimator or sklearn Pipeline while passing metadata only to steps + that explicitly support it. + + Normal sklearn estimators remain supported because unsupported metadata + arguments are not passed to them. 
+ """ + if not isinstance(estimator, Pipeline): + fit_kwargs = {} + + if subjects_train is not None and estimator_accepts_argument( + estimator, "fit", "subjects" + ): + fit_kwargs["subjects"] = subjects_train + + if cs_mode is not None and estimator_accepts_argument( + estimator, "fit", "cs_mode" + ): + fit_kwargs["cs_mode"] = cs_mode + + if X_target_unlabeled is not None and estimator_accepts_argument( + estimator, "fit", "X_target_unlabeled" + ): + fit_kwargs["X_target_unlabeled"] = X_target_unlabeled + + if X_target_labeled is not None and estimator_accepts_argument( + estimator, "fit", "X_target_labeled" + ): + fit_kwargs["X_target_labeled"] = X_target_labeled + + if y_target_labeled is not None and estimator_accepts_argument( + estimator, "fit", "y_target_labeled" + ): + fit_kwargs["y_target_labeled"] = y_target_labeled + + estimator.fit(X_train, y_train, **fit_kwargs) + return estimator + + Xt_train = X_train + Xt_target_unlabeled = X_target_unlabeled + Xt_target_labeled = X_target_labeled + + for _step_name, step in estimator.steps[:-1]: + (Xt_train, Xt_target_unlabeled, Xt_target_labeled) = ( + _fit_transform_step_with_metadata( + step=step, + X=Xt_train, + y=y_train, + subjects=subjects_train, + cs_mode=cs_mode, + X_target_unlabeled=Xt_target_unlabeled, + X_target_labeled=Xt_target_labeled, + y_target_labeled=y_target_labeled, + ) + ) + + _final_name, final_step = estimator.steps[-1] + + final_fit_kwargs = {} + + if subjects_train is not None and estimator_accepts_argument( + final_step, "fit", "subjects" + ): + final_fit_kwargs["subjects"] = subjects_train + + if cs_mode is not None and estimator_accepts_argument(final_step, "fit", "cs_mode"): + final_fit_kwargs["cs_mode"] = cs_mode + + if Xt_target_unlabeled is not None and estimator_accepts_argument( + final_step, "fit", "X_target_unlabeled" + ): + final_fit_kwargs["X_target_unlabeled"] = Xt_target_unlabeled + + if Xt_target_labeled is not None and estimator_accepts_argument( + final_step, "fit", 
"X_target_labeled" + ): + final_fit_kwargs["X_target_labeled"] = Xt_target_labeled + + if y_target_labeled is not None and estimator_accepts_argument( + final_step, "fit", "y_target_labeled" + ): + final_fit_kwargs["y_target_labeled"] = y_target_labeled + + final_step.fit(Xt_train, y_train, **final_fit_kwargs) + + return estimator + + +class TrialwisePredictWrapper(ClassifierMixin, BaseEstimator): + """ + Wrap an already-fitted estimator and force one-trial-at-a-time prediction. + + This keeps MOABB/sklearn scorer compatibility while preventing the wrapped + estimator from receiving the full target test block during prediction. + + Has multi-class support. + """ + + _estimator_type = "classifier" + + def __init__(self, fitted_estimator): + self.fitted_estimator = fitted_estimator + _ensure_fitted(fitted_estimator) + self.classes_ = self._get_classes(fitted_estimator) + + def fit(self, X, y=None): + raise RuntimeError("TrialwisePredictWrapper wraps an already fitted estimator.") + + def predict(self, X): + return np.asarray( + [ + self.fitted_estimator.predict(self._slice_one(X, i))[0] + for i in range(len(X)) + ] + ) + + def predict_proba(self, X): + if not hasattr(self.fitted_estimator, "predict_proba"): + raise AttributeError("Wrapped estimator does not provide predict_proba.") + + rows = [ + self._first_row(self.fitted_estimator.predict_proba(self._slice_one(X, i))) + for i in range(len(X)) + ] + + return np.vstack(rows) + + def decision_function(self, X): + if not hasattr(self.fitted_estimator, "decision_function"): + raise AttributeError("Wrapped estimator does not provide decision_function.") + + rows = [ + self._first_row( + self.fitted_estimator.decision_function(self._slice_one(X, i)) + ) + for i in range(len(X)) + ] + + out = np.asarray(rows) + + # sklearn convention: + # binary decision_function -> shape (n_samples,) + # multiclass decision_function -> shape (n_samples, n_classes) + if out.ndim == 2 and out.shape[1] == 1: + return out.ravel() + + return 
def __init__(
    self, *args, cs_mode: CsMode = CsMode.HOS_SOURCE_ONLY_BLOCKWISE, **kwargs
):
    """
    Initialise the parent evaluation, then record the target-aware mode.

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to the parent class constructor.
    cs_mode : CsMode
        Mode stored on ``self.cs_mode``. Defaults to
        ``CsMode.HOS_SOURCE_ONLY_BLOCKWISE``.

    Raises
    ------
    ValueError
        If ``cs_mode`` is not a ``CsMode`` instance. The check runs after the
        parent constructor has been called.
    """
    super().__init__(*args, **kwargs)

    if isinstance(cs_mode, CsMode):
        self.cs_mode = cs_mode
        return

    raise ValueError(f"cs_mode must be an instance of CsMode. Got {cs_mode!r}.")
def _build_task_list(self, dataset, y, metadata, splitter, work_plan, param_grid):
    """
    Build lightweight flattened MOABB 1.6 tasks with target-aware split
    metadata.

    The task dictionary intentionally does not store X, y, metadata, groups,
    or sessions. These are passed separately to the worker to avoid duplicating
    large objects in every joblib task.

    One task is created per (held-out subject, pipeline) pair. If a fold
    holds out several subjects, each of them becomes its own target: its own
    indices are split into adaptation and evaluation parts, so adaptation
    data and reported subject always refer to the same person. For folds
    with a single held-out subject this is identical to using the whole
    test fold.

    Parameters
    ----------
    dataset : object
        Dataset being evaluated; only ``dataset.code`` is read here (for the
        warning message) and the object is stored in each task.
    y : array-like
        Labels, forwarded to ``splitter.split``.
    metadata : DataFrame-like
        Must provide a ``"subject"`` column.
    splitter : object
        Provides ``split(y, metadata)`` yielding ``(train_idx, test_idx)``.
    work_plan : mapping
        Maps a subject (or its ``str`` form) to a ``{name: pipeline}`` mapping.
        Subjects missing from it are skipped.
    param_grid : object
        Stored unchanged in each task.

    Returns
    -------
    list of dict
        Task dictionaries with keys ``dataset``, ``train_idx``, ``test_idx``,
        ``target_adapt_idx``, ``eval_idx``, ``subject``, ``pipeline_name``,
        ``pipeline``, ``param_grid`` and ``cv_ind``.
    """
    groups = metadata["subject"].values

    tasks = []

    for cv_ind, (train_idx, test_idx) in enumerate(splitter.split(y, metadata)):
        train_idx = np.asarray(train_idx, dtype=int)
        test_idx = np.asarray(test_idx, dtype=int)

        # Nothing is held out in this fold, so there is no target to evaluate.
        if test_idx.size == 0:
            continue

        fold_subjects = groups[test_idx]

        # Held-out subjects in order of first appearance within the fold.
        _, first_pos = np.unique(fold_subjects, return_index=True)

        for subject in fold_subjects[np.sort(first_pos)]:
            if subject in work_plan:
                subject_key = subject
            elif str(subject) in work_plan:
                subject_key = str(subject)
            else:
                continue

            # Boolean masking keeps the original ordering of this subject's
            # trials, which the positional adaptation/evaluation split
            # relies on.
            subject_test_idx = test_idx[fold_subjects == subject]

            target_adapt_idx, eval_idx = split_target_for_cs_mode(
                subject_test_idx, self.cs_mode
            )

            if len(eval_idx) == 0:
                warnings.warn(
                    f"{dataset.code} | subject={subject}: empty evaluation set "
                    f"for cs_mode={self.cs_mode.name}. Skipping fold.",
                    RuntimeWarning,
                )
                continue

            for pipeline_name, pipeline in work_plan[subject_key].items():
                tasks.append(
                    {
                        "dataset": dataset,
                        "train_idx": train_idx,
                        "test_idx": subject_test_idx,
                        "target_adapt_idx": target_adapt_idx,
                        "eval_idx": eval_idx,
                        "subject": subject,
                        "pipeline_name": pipeline_name,
                        "pipeline": pipeline,
                        "param_grid": param_grid,
                        "cv_ind": cv_ind,
                    }
                )

    return tasks


def _create_splitter(self):
    """
    Create the MOABB 1.6 cross-subject splitter.

    This delegates subject-level CV handling to MOABB's CrossSubjectSplitter,
    so cv_class, random_state, and cv_kwargs keep the same meaning as in
    CrossSubjectEvaluation.

    ``cv_kwargs`` and ``cv_class`` are read with ``getattr`` so a missing
    attribute behaves like an unset one. ``cv_class`` is only forwarded when
    it is set, which leaves the splitter's own default in place otherwise.
    """
    cv_kwargs = getattr(self, "cv_kwargs", {}) or {}
    cv_class = getattr(self, "cv_class", None)

    if cv_class is None:
        return CrossSubjectSplitter(random_state=self.random_state, **cv_kwargs)

    return CrossSubjectSplitter(
        cv_class=cv_class, random_state=self.random_state, **cv_kwargs
    )
def _evaluate_task(self, task, X, y, groups, sessions):
    """
    Evaluate one flattened target-aware task.

    Large shared objects are passed as worker arguments instead of being stored
    inside each task dictionary.

    The pipeline is cloned, fitted on the source (training) trials together
    with whatever target data the current ``cs_mode`` allows, optionally
    saved, and then scored separately for every session found in the
    evaluation indices.

    Parameters
    ----------
    task : dict
        Task dictionary with keys ``dataset``, ``train_idx``, ``test_idx``,
        ``target_adapt_idx``, ``eval_idx``, ``subject``, ``pipeline_name``,
        ``pipeline``, ``param_grid`` and ``cv_ind``.
    X : indexable
        Data for all loaded trials; indexed with integer index arrays.
    y : numpy.ndarray
        Labels aligned with ``X``.
    groups : numpy.ndarray
        Subject identifier of every trial.
    sessions : numpy.ndarray
        Session identifier of every trial.

    Returns
    -------
    list
        One result record per evaluated session, each extended with the
        target-aware bookkeeping fields (``cs_mode``, sample counts,
        ``target_subject``, ``predict_mode``).

    Raises
    ------
    NotImplementedError
        If the task carries a ``param_grid`` that is not ``None``.
    """
    dataset = task["dataset"]

    train_idx = task["train_idx"]
    test_idx = task["test_idx"]
    target_adapt_idx = task["target_adapt_idx"]
    # Boolean-mask indexing below needs an ndarray, not a plain sequence.
    eval_idx = np.asarray(task["eval_idx"], dtype=int)

    subject = task["subject"]
    name = task["pipeline_name"]
    clf = task["pipeline"]
    param_grid = task["param_grid"]
    cv_ind = task["cv_ind"]

    if param_grid is not None:
        raise NotImplementedError(
            "param_grid/grid search is not supported by "
            "CrossSubjectTargetAwareEvaluation yet. Inner GridSearchCV does "
            "not pass subjects, X_target_unlabeled, X_target_labeled, or "
            "y_target_labeled to inner fits. Please set param_grid=None."
        )

    nchan = self._get_nchan(X)

    cvclf = clone(clf)

    X_train = X[train_idx]

    if self.mne_labels:
        y_train = y[train_idx]
        y_eval_all = y
    else:
        # Fit the encoder on every label of this fold (source and target) so
        # train and evaluation labels share one mapping.
        fold_label_idx = np.unique(np.concatenate([train_idx, test_idx]))

        le = LabelEncoder()
        le.fit(y[fold_label_idx])

        y_train = le.transform(y[train_idx])

        # -1 marks trials outside this fold; only fold indices are read
        # below, but a fixed placeholder avoids uninitialised memory.
        y_eval_all = np.full(y.shape, -1, dtype=int)
        y_eval_all[fold_label_idx] = le.transform(y[fold_label_idx])

    subjects_train = groups[train_idx]

    X_target_unlabeled = None
    X_target_labeled = None
    y_target_labeled = None

    n_target_unlabeled = 0
    n_target_labeled = 0

    if cs_mode_uses_unlabeled_target(self.cs_mode):
        X_target_unlabeled = X[target_adapt_idx]
        n_target_unlabeled = int(len(target_adapt_idx))

    elif cs_mode_uses_labeled_target(self.cs_mode):
        X_target_labeled = X[target_adapt_idx]
        y_target_labeled = y_eval_all[target_adapt_idx]
        n_target_labeled = int(len(target_adapt_idx))

    duration = self._fit_estimator_with_target_metadata(
        estimator=cvclf,
        X_train=X_train,
        y_train=y_train,
        subjects_train=subjects_train,
        cs_mode=self.cs_mode.name,
        X_target_unlabeled=X_target_unlabeled,
        X_target_labeled=X_target_labeled,
        y_target_labeled=y_target_labeled,
    )

    self._maybe_save_model_cv(
        cvclf, dataset, subject, "", name, cv_ind, eval_type=self._eval_type
    )

    trialwise = self.cs_mode == CsMode.HOS_SOURCE_ONLY_TRIALWISE

    # In trial-wise mode the fitted model is wrapped so it only ever sees one
    # evaluation trial per prediction call.
    scoring_estimator = TrialwisePredictWrapper(cvclf) if trialwise else cvclf

    scorer = _create_scorer(scoring_estimator, self.paradigm.scoring)

    results = []

    # Session of every evaluation trial, looked up once for all sessions.
    eval_sessions = sessions[eval_idx]

    for session in np.unique(eval_sessions):
        # Never empty: each session here comes from eval_sessions itself.
        session_eval_idx = eval_idx[eval_sessions == session]

        res = self._build_scored_result(
            dataset=dataset,
            subject=subject,
            session=session,
            pipeline=name,
            n_samples=len(X_train),
            n_channels=nchan,
            duration=duration,
            scorer=scorer,
            model=scoring_estimator,
            X_test=X[session_eval_idx],
            y_test=y_eval_all[session_eval_idx],
        )

        res["cs_mode"] = self.cs_mode.name
        res["n_source_train"] = int(len(train_idx))
        res["n_source_fit"] = int(len(X_train))
        res["n_train_total"] = int(len(X_train) + n_target_labeled)
        res["n_heldout_total"] = int(len(test_idx))
        res["n_target_adapt"] = int(len(target_adapt_idx))
        res["n_target_eval"] = int(len(eval_idx))
        res["n_target_unlabeled"] = int(n_target_unlabeled)
        res["n_target_labeled"] = int(n_target_labeled)
        res["target_subject"] = subject
        res["predict_mode"] = "trialwise" if trialwise else "blockwise"

        results.append(res)

    return results
def _evaluate_parallel_dataset(
    self,
    dataset,
    pipelines,
    param_grid,
    process_pipeline,
    postprocess_pipeline,
    work_plan,
):
    """
    Flattened parallel evaluation of one dataset (MOABB > 1.6 style).

    One task = one held-out subject fold x one pipeline.

    Task dictionaries stay lightweight; the large shared arrays (data,
    labels, subject and session ids) are handed to each worker call as
    separate arguments.

    Returns a flat list with the result rows of every task, or an empty list
    when no task could be built. Raises ``NotImplementedError`` when
    ``param_grid`` is not ``None``.
    """
    from joblib import Parallel, delayed

    if param_grid is not None:
        raise NotImplementedError(
            "param_grid/grid search is not supported by "
            "CrossSubjectTargetAwareEvaluation yet. Inner GridSearchCV does "
            "not pass subjects, X_target_unlabeled, X_target_labeled, or "
            "y_target_labeled to inner fits. Please set param_grid=None."
        )

    # Either every subject of the dataset or only those in the work plan.
    if getattr(self, "_needs_all_subjects", False):
        subjects_to_load = dataset.subject_list
    else:
        subjects_to_load = list(work_plan.keys())

    # Merge the per-subject pipeline mappings; on a name clash the entry of
    # the later subject wins.
    run_pipes = {}
    for subject_pipelines in work_plan.values():
        for pipe_name, pipe in subject_pipelines.items():
            run_pipes[pipe_name] = pipe

    X, y_raw, metadata = self._load_data(
        dataset=dataset,
        run_pipes=run_pipes,
        process_pipeline=process_pipeline,
        postprocess_pipeline=postprocess_pipeline,
        subjects=subjects_to_load,
    )

    y = np.asarray(y_raw)
    groups = metadata["subject"].values
    sessions = metadata["session"].values

    tasks = self._build_task_list(
        dataset=dataset,
        y=y,
        metadata=metadata,
        splitter=self._create_splitter(),
        work_plan=work_plan,
        param_grid=param_grid,
    )

    if not tasks:
        return []

    desc = f"{dataset.code}-{self._eval_type}"
    shared = {"X": X, "y": y, "groups": groups, "sessions": sessions}

    if self.n_jobs == 1:
        # Serial path: evaluate in-process with a plain progress bar.
        progress = tqdm(
            tasks, total=len(tasks), desc=desc, unit="task", dynamic_ncols=True
        )
        nested_results = [
            self._evaluate_task(task=task, **shared) for task in progress
        ]
    else:
        with tqdm_joblib(
            tqdm(total=len(tasks), desc=desc, unit="task", dynamic_ncols=True)
        ):
            nested_results = Parallel(n_jobs=self.n_jobs, verbose=0)(
                delayed(self._evaluate_task)(task=task, **shared) for task in tasks
            )

    # Each task returns a list of rows; flatten them in task order.
    return [row for rows in nested_results for row in rows]
def _fit_estimator_with_target_metadata(
    self,
    estimator: Any,
    X_train: Any,
    y_train: np.ndarray,
    subjects_train: Optional[np.ndarray] = None,
    cs_mode: Optional[str] = None,
    X_target_unlabeled: Optional[Any] = None,
    X_target_labeled: Optional[Any] = None,
    y_target_labeled: Optional[np.ndarray] = None,
) -> float:
    """
    Fit one estimator with optional target-aware metadata.

    All arguments are forwarded unchanged to
    ``fit_pipeline_with_subject_metadata``; afterwards the estimator is
    checked with ``_ensure_fitted``.

    Parameters
    ----------
    estimator : Any
        Estimator to fit (fitted in place by the helper).
    X_train, y_train
        Source training data and labels.
    subjects_train : numpy.ndarray, optional
        Subject identifier of every training trial.
    cs_mode : str, optional
        Name of the evaluation mode.
    X_target_unlabeled : Any, optional
        Target data made available without labels.
    X_target_labeled, y_target_labeled : optional
        Target calibration data and its labels.

    Returns
    -------
    float
        Elapsed time in seconds for the fit and the fitted-state check.
    """
    # perf_counter is monotonic, so the measured interval cannot be distorted
    # by system clock adjustments the way a wall-clock difference can.
    start_time = time.perf_counter()

    fit_pipeline_with_subject_metadata(
        estimator=estimator,
        X_train=X_train,
        y_train=y_train,
        subjects_train=subjects_train,
        cs_mode=cs_mode,
        X_target_unlabeled=X_target_unlabeled,
        X_target_labeled=X_target_labeled,
        y_target_labeled=y_target_labeled,
    )

    _ensure_fitted(estimator)

    return time.perf_counter() - start_time