
NeuroTTT

Bridging Pretrain–Downstream Task Misalignment in EEG Foundation Models via Test-Time Training


🎯 Overview

NeuroTTT is a framework that bridges the gap between pretrained EEG foundation models and downstream tasks through Test-Time Training (TTT). It addresses the misalignment between what these models learn during pretraining and what downstream tasks require by introducing:

  • Domain-specific self-supervised fine-tuning that augments foundation models with task-relevant objectives
  • Test-time training for individual unlabeled test samples during inference
  • Prediction entropy minimization (Tent) for continual model calibration

The framework integrates multiple state-of-the-art components:

  • CBraMod: A Criss-Cross Brain Foundation Model for EEG decoding
  • Tent: Fully test-time adaptation by entropy minimization
  • Multiple pretext tasks: Band filtering, temporal ordering, channel masking, and more


🔗 Citation

If you use NeuroTTT in your research, please cite:

@misc{wang2025neurotttbridgingpretrainingdownstreamtask,
      title={NeuroTTT: Bridging Pretraining-Downstream Task Misalignment in EEG Foundation Models via Test-Time Training}, 
      author={Suli Wang and Yangshen Deng and Zhenghua Bao and Xinyu Zhan and Yiqun Duan},
      year={2025},
      eprint={2509.26301},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2509.26301}, 
}

✨ Key Features

  • 🧠 Foundation Model Integration: Built on CBraMod, a state-of-the-art EEG foundation model
  • 🔄 Test-Time Training: Self-supervised adaptation during inference without labeled data
  • 🎯 Multiple Pretext Tasks: Band filtering, temporal ordering, channel masking, phase prediction
  • 📊 Comprehensive Dataset Support: Three built-in NeuroTTT datasets (imagined speech, motor imagery, mental workload) plus CBraMod's original datasets
  • ⚡ Efficient Implementation: Optimized for both research and practical applications
  • 🔧 Flexible Configuration: YAML split configs and command-line flags for easy experimentation, with the environment managed by uv

🔨 Installation

Prerequisites

  • Python 3.10+
  • CUDA-compatible GPU (recommended)
  • uv package manager

Quick Installation

  1. Install the uv package manager:
curl -LsSf https://astral.sh/uv/install.sh | sh
  2. Clone the repository:
git clone https://github.com/wsl2000/NeuroTTT.git
cd NeuroTTT
  3. Initialize the environment:
bash init_env.sh
  4. Activate the environment:
source .venv/bin/activate

🚀 Quick Start

Fine-tuning Example

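# DATASET_PROCESSED_ROOT should point to your preprocessed-data directory, including the trailing slash
DATASET_PROCESSED_SPEECH="${DATASET_PROCESSED_ROOT}BCIC2020-3/"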
PRETRAINED_WEIGHTS="./CBraMod/pretrained_weights/pretrained_weights.pth"
MODEL_WEIGHTS_ROOT="./model_weights/BCIC2020-3/"
SPLIT_CONFIG="./configs/speech_dataset_default.yaml"

# Fine-tune on BCIC2020-3 dataset
python CBraMod/finetune_main.py \
    --epochs 50 \
    --cuda 0 \
    --seed 8888 \
    --batch_size 64 \
    --lr 5e-4 \
    --split_config ${SPLIT_CONFIG} \
    --multi_lr 0 \
    --weight_decay 5e-2 \
    --dropout 0.1 \
    \
    --downstream_dataset BCIC2020-3 \
    --datasets_dir ${DATASET_PROCESSED_SPEECH} \
    --num_of_classes 5 \
    --model_dir ${MODEL_WEIGHTS_ROOT} \
    --use_pretrained_weights True \
    --foundation_dir ${PRETRAINED_WEIGHTS} \
    --classifier all_patch_reps \
    --pretext none

📊 Supported Datasets

NeuroTTT provides built-in support for three EEG datasets and can be extended to others. CBraMod’s original dataset support is also retained for convenience:

NeuroTTT Datasets

  • BCIC2020-3: Imagined speech classification (5 classes)
  • BCIC-IV-2a: Motor imagery classification (4 classes)
  • MentalArithmetic: Mental workload classification (binary)

Clinical Applications

  • CHB-MIT: Seizure detection (binary classification)
  • TUAB: Abnormal EEG detection (binary classification)
  • TUEV: EEG event type classification (6 classes)
  • ISRUC: Sleep stage classification (5 stages)

Cognitive & Emotional States

  • SEED-V: Emotion recognition (5 emotions)
  • SEED-VIG: Vigilance estimation (regression)
  • FACED: Emotion recognition (9 classes)
  • Mumtaz2016: Mental disorder (major depressive disorder) diagnosis (binary)

Stress & Workload

  • Stress Dataset: Stress level classification

🏋️ Training

1. Fine-tuning with Pretext Tasks

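Fine-tune with self-supervised pretext tasks as auxiliary objectives; each --pretext_weight_* flag sets the weight of the corresponding pretext loss: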

# With band filtering pretext task
python CBraMod/finetune_main.py \
    --downstream_dataset MentalArithmetic \
    --pretext band \
    --pretext_weight_band 0.1 \
    --epochs 20 \
    --lr 1e-4

# With multiple pretext tasks
python CBraMod/finetune_main.py \
    --downstream_dataset BCIC-IV-2a \
    --pretext all \
    --pretext_weight_band 0.2 \
    --pretext_weight_temporal 0.6 \
    --epochs 20
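The --pretext_weight_* flags suggest that each pretext loss is added to the supervised loss with its own weight. A minimal sketch of such a joint objective, assuming a cross-entropy task loss and a plain weighted sum (joint_loss is a hypothetical name, and the exact combination inside finetune_main.py may differ):

```python
import torch
import torch.nn.functional as F


def joint_loss(logits, labels, pretext_losses, pretext_weights):
    """Hypothetical joint objective: task cross-entropy plus weighted pretext losses.

    pretext_losses and pretext_weights are dicts keyed by task name
    (e.g. "band", "temporal"); a task without a weight contributes nothing.
    """
    loss = F.cross_entropy(logits, labels)
    for name, value in pretext_losses.items():
        loss = loss + pretext_weights.get(name, 0.0) * value
    return loss


# Weights mirror the "--pretext all" command above.
logits = torch.tensor([[2.0, 0.0], [0.0, 2.0]])
labels = torch.tensor([0, 1])
losses = {"band": torch.tensor(1.0), "temporal": torch.tensor(0.5)}
weights = {"band": 0.2, "temporal": 0.6}
total = joint_loss(logits, labels, losses, weights)
```

With this form, a larger weight makes the auxiliary task pull harder on the shared backbone relative to the classification loss.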

2. Available Pretext Tasks

  • band: Frequency band filtering and reconstruction
  • temporal: Temporal order prediction
  • channel: Channel masking and reconstruction
  • phase: Phase prediction tasks
  • reverse: Reverse sequence prediction
  • all: Combination of multiple pretext tasks
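The task names above suggest how self-supervised inputs and targets can be derived from an unlabeled trial. The sketch below is one plausible reading of four of them (band, temporal, channel, reverse), assuming a single trial tensor shaped (channels, time); the function names, segment count, mask ratio, and band edges are hypothetical and not taken from the NeuroTTT code:

```python
import torch


def reverse_task(x):
    """'reverse' sketch: flip the time axis with probability 0.5; label 1 if flipped."""
    flipped = bool(torch.rand(()) < 0.5)
    out = torch.flip(x, dims=[-1]) if flipped else x
    return out, int(flipped)


def temporal_task(x, num_segments=4):
    """'temporal' sketch: shuffle equal time segments; the target is the permutation."""
    segments = torch.chunk(x, num_segments, dim=-1)
    perm = torch.randperm(len(segments))
    out = torch.cat([segments[i] for i in perm.tolist()], dim=-1)
    return out, perm


def channel_task(x, mask_ratio=0.25):
    """'channel' sketch: zero out random channels; the target is the original signal."""
    num_channels = x.shape[0]
    num_masked = max(1, round(mask_ratio * num_channels))
    masked = torch.randperm(num_channels)[:num_masked]
    out = x.clone()
    out[masked] = 0.0
    return out, x, masked


def band_task(x, fs, low_hz, high_hz):
    """'band' sketch: remove one frequency band with an FFT mask; the target is the original."""
    spectrum = torch.fft.rfft(x, dim=-1)
    freqs = torch.fft.rfftfreq(x.shape[-1], d=1.0 / fs)
    keep = (freqs < low_hz) | (freqs > high_hz)
    filtered = torch.fft.irfft(spectrum * keep.to(x.dtype), n=x.shape[-1], dim=-1)
    return filtered, x


trial = torch.randn(22, 1000)  # e.g. 22 channels, 5 s at 200 Hz
alpha_removed, target = band_task(trial, fs=200, low_hz=8, high_hz=13)
```

Each function returns an input the model sees plus a target it must recover, so the loss needs no labels and can also be computed on test data.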

🧪 Test-Time Adaptation

Self-Supervised Test-Time Training

Adapt the model to individual test samples using self-supervised objectives:

python CBraMod/ttt_main.py \
    --downstream_dataset BCIC2020-3 \
    --model_path ./model_weights/BCIC2020-3/best_model.pth \
    --ttt_lr 1e-4 \
    --ttt_steps 5 \
    --pretext band \
    --split_config ./configs/speech_dataset_default.yaml
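The command above adapts the model separately for each test sample. A minimal sketch of that idea, assuming a model with a shared encoder, a task head, and a pretext head, and using the reverse pretext instead of band for brevity. ToyTTTModel, reverse_pretext_loss, ttt_predict, and the choice of SGD are all hypothetical; only the roles of ttt_lr and ttt_steps follow the flags above:

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyTTTModel(nn.Module):
    """Hypothetical stand-in: shared encoder + task head + pretext head."""

    def __init__(self, channels=4, time=64, hidden=32, num_classes=5):
        super().__init__()
        self.encoder = nn.Sequential(nn.Flatten(), nn.Linear(channels * time, hidden), nn.ReLU())
        self.task_head = nn.Linear(hidden, num_classes)
        self.pretext_head = nn.Linear(hidden, 2)  # "was this input time-reversed?"

    def forward(self, x):
        z = self.encoder(x)
        return self.task_head(z), self.pretext_head(z)


def reverse_pretext_loss(model, x):
    """Self-supervised loss on unlabeled x: tell original from time-reversed inputs."""
    inputs = torch.cat([x, torch.flip(x, dims=[-1])], dim=0)
    targets = torch.cat([torch.zeros(len(x)), torch.ones(len(x))]).long()
    _, pretext_logits = model(inputs)
    return F.cross_entropy(pretext_logits, targets)


def ttt_predict(model, x, ttt_lr=1e-4, ttt_steps=5):
    """Adapt a copy of the model on one unlabeled test sample, then predict.

    The original model is never modified, so every test sample starts from
    the same fine-tuned weights.
    """
    adapted = copy.deepcopy(model)
    adapted.train()
    optimizer = torch.optim.SGD(adapted.parameters(), lr=ttt_lr)
    for _ in range(ttt_steps):
        optimizer.zero_grad()
        reverse_pretext_loss(adapted, x).backward()  # only encoder + pretext head get gradients
        optimizer.step()
    adapted.eval()
    with torch.no_grad():
        task_logits, _ = adapted(x)
    return task_logits.argmax(dim=-1)


model = ToyTTTModel()
sample = torch.randn(1, 4, 64)  # one unlabeled trial: (batch=1, channels, time)
prediction = ttt_predict(model, sample, ttt_lr=1e-4, ttt_steps=5)
```

Because the pretext loss updates the shared encoder, the task head sees features adjusted to the test sample even though no label was used.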

Entropy Minimization (Tent)

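Apply Tent to adapt the model during inference by minimizing prediction entropy. In the original Tent implementation, episodic mode resets the model to its initial state before each batch; without --episodic, updates accumulate across batches for continual adaptation: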

python CBraMod/tta_main.py \
    --test_time_method tent \
    --downstream_dataset MentalArithmetic \
    --model_path ./model_weights/MentalArithmetic/best_model.pth \
    --lr 1e-3 \
    --steps 1 \
    --episodic
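For reference, the core of Tent can be sketched as follows, modeled on the original algorithm: normalization layers use test-batch statistics, only their affine parameters are trained, and the loss is the mean prediction entropy. This sketch handles BatchNorm only, as in the original paper; which parameters NeuroTTT adapts inside CBraMod is not specified here. TentSketch and configure_for_tent are hypothetical names, and Adam is an assumed optimizer choice:

```python
import copy

import torch
import torch.nn as nn


def softmax_entropy(logits):
    """Per-sample entropy of the predicted class distribution."""
    return -(logits.softmax(dim=1) * logits.log_softmax(dim=1)).sum(dim=1)


def configure_for_tent(model):
    """Freeze everything except BatchNorm affine parameters (returned for the optimizer).

    Running statistics are dropped, so BatchNorm normalizes each test batch
    with that batch's own mean and variance.
    """
    model.train()
    model.requires_grad_(False)
    params = []
    for m in model.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
            m.requires_grad_(True)
            m.track_running_stats = False
            m.running_mean = None
            m.running_var = None
            params += [m.weight, m.bias]
    return params


class TentSketch:
    """Hypothetical Tent wrapper: entropy minimization on each incoming test batch."""

    def __init__(self, model, lr=1e-3, steps=1, episodic=False):
        self.model = model
        self.steps = steps
        self.episodic = episodic
        self.optimizer = torch.optim.Adam(configure_for_tent(model), lr=lr)
        # Snapshot used to reset model and optimizer in episodic mode.
        self.model_state = copy.deepcopy(model.state_dict())
        self.optimizer_state = copy.deepcopy(self.optimizer.state_dict())

    def reset(self):
        self.model.load_state_dict(self.model_state)
        self.optimizer.load_state_dict(self.optimizer_state)

    def __call__(self, x):
        if self.episodic:
            self.reset()
        for _ in range(self.steps):
            logits = self.model(x)  # predictions use the weights before this step's update
            loss = softmax_entropy(logits).mean()
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
        return logits
```

With episodic=True, two calls on the same batch return identical predictions, because each call starts from the saved snapshot.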

Test-Time Methods

  • source: No adaptation (baseline)
  • tent: Entropy minimization adaptation
  • norm: Batch normalization statistics update
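The norm baseline needs no gradients at all. A minimal sketch in the spirit of tent/norm.py, assuming BatchNorm layers: switch only the normalization layers to training mode so they normalize with the current test batch's statistics (updating their running estimates as a side effect), then predict. norm_adapt is a hypothetical name:

```python
import torch
import torch.nn as nn


def norm_adapt(model, x):
    """Predict with BatchNorm layers normalizing by test-batch statistics.

    Other layers stay in eval mode (e.g. dropout is disabled). No parameters
    are trained, but BatchNorm running estimates are updated.
    """
    model.eval()
    for m in model.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
            m.train()
    with torch.no_grad():
        return model(x)
```

Compared with tent, this only replaces the normalization statistics; the affine parameters and all other weights stay exactly as fine-tuned.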

📁 Project Structure

NeuroTTT/
├── CBraMod/                    # CBraMod foundation model
│   ├── models/                 # Model architectures
│   ├── datasets/               # Dataset loaders
│   ├── preprocessing/          # Data preprocessing scripts
│   ├── pretrained_weights/     # Pretrained model weights
│   ├── finetune_main.py        # Fine-tuning script
│   ├── ttt_main.py             # Test-time training script
│   └── tta_main.py             # Test-time adaptation script
├── tent/                       # Tent implementation
│   ├── tent.py                 # Core Tent algorithm
│   ├── norm.py                 # Normalization methods
│   └── cifar10c.py             # Example usage
├── configs/                    # Configuration files
│   ├── speech_dataset_default.yaml
│   └── ...
├── figure/                     # Figures and diagrams
├── init_env.sh                 # Environment setup script
└── README.md                   # This file

⚙️ Configuration

Dataset Configuration

Configure dataset splits using YAML files in the configs/ directory:

# Example: speech_dataset_default.yaml
trial_range: [0, 399]
subject_range: [1, 16]
test_split_by: trial
val_split_by: trial
test_split: [350, 399]
val_split: [300, 349]
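One way to read this file, assuming each [start, end] pair is inclusive on both ends and that training uses every remaining trial. These semantics are an assumption; the NeuroTTT dataset loaders define how trial_range, subject_range, and the *_split_by keys are actually interpreted. split_trials is a hypothetical helper, and subject_range is ignored here:

```python
def inclusive_range(bounds):
    """Assumed semantics: [start, end] includes both endpoints."""
    start, end = bounds
    return set(range(start, end + 1))


def split_trials(cfg):
    """Hypothetical trial-level split derived from a split config dict."""
    if cfg["test_split_by"] != "trial" or cfg["val_split_by"] != "trial":
        raise NotImplementedError("this sketch only handles trial-based splits")
    all_trials = inclusive_range(cfg["trial_range"])
    test = inclusive_range(cfg["test_split"])
    val = inclusive_range(cfg["val_split"])
    if test & val:
        raise ValueError("validation and test trials overlap")
    train = all_trials - test - val
    return sorted(train), sorted(val), sorted(test)


# Values copied from speech_dataset_default.yaml above.
cfg = {
    "trial_range": [0, 399],
    "subject_range": [1, 16],
    "test_split_by": "trial",
    "val_split_by": "trial",
    "test_split": [350, 399],
    "val_split": [300, 349],
}
train, val, test = split_trials(cfg)
print(len(train), len(val), len(test))  # 300 50 50 under the inclusive assumption
```

In practice the dict would come from parsing the YAML file, for example with PyYAML's yaml.safe_load.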

Model Configuration

Key parameters for model configuration:

  • classifier: Classification head over the patch representations (all_patch_reps, all_patch_reps_onelayer, or avgpooling_patch_reps)
  • dropout: Dropout rate (default: 0.1)
  • pretext: Pretext task selection
  • pretext_weight_*: Weights for different pretext tasks
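The three classifier options differ in how they use the backbone's patch representations. A hedged sketch, assuming representations shaped (batch, channels, patches, d_model), which is our reading of CBraMod's output; build_classifier and MeanPoolPatches are hypothetical, and the hidden sizes of the all_patch_reps MLP are illustrative rather than taken from the code:

```python
import torch
import torch.nn as nn


class MeanPoolPatches(nn.Module):
    """Average over the channel and patch axes: (B, C, P, D) -> (B, D)."""

    def forward(self, x):
        return x.mean(dim=(1, 2))


def build_classifier(name, num_channels, num_patches, d_model, num_classes, dropout=0.1):
    flat_dim = num_channels * num_patches * d_model
    if name == "avgpooling_patch_reps":
        # Pool all patch representations into one vector, then a linear layer.
        return nn.Sequential(MeanPoolPatches(), nn.Linear(d_model, num_classes))
    if name == "all_patch_reps_onelayer":
        # Keep every patch representation; one linear layer on the flattened tensor.
        return nn.Sequential(nn.Flatten(start_dim=1), nn.Linear(flat_dim, num_classes))
    if name == "all_patch_reps":
        # Keep every patch representation; a small MLP with dropout on the flattened tensor.
        hidden = num_patches * d_model
        return nn.Sequential(
            nn.Flatten(start_dim=1),
            nn.Linear(flat_dim, hidden), nn.ELU(), nn.Dropout(dropout),
            nn.Linear(hidden, d_model), nn.ELU(), nn.Dropout(dropout),
            nn.Linear(d_model, num_classes),
        )
    raise ValueError(f"unknown classifier: {name}")


reps = torch.randn(2, 22, 4, 200)  # e.g. 22 channels, 4 patches, 200-dim patch embeddings
for name in ("avgpooling_patch_reps", "all_patch_reps_onelayer", "all_patch_reps"):
    head = build_classifier(name, 22, 4, 200, num_classes=4)
    print(name, tuple(head(reps).shape))  # (2, 4) for every head
```

Pooling discards where in the montage and in time each feature came from, while the flattened heads keep that layout at the cost of many more parameters.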

🙏 Acknowledgments

  • CBraMod: Original implementation by wjq-learning
  • Tent: Original implementation accompanying the ICLR 2021 paper
  • All dataset providers and the EEG research community