diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2f56c87..4b5e202 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,9 +2,8 @@ name: CI on: push: - branches: [ "dev", "master" ] + branches: ["dev", "master"] pull_request: - branches: [ "dev", "master" ] workflow_dispatch: jobs: diff --git a/.gitignore b/.gitignore index 1c51e0c..0a3e713 100644 --- a/.gitignore +++ b/.gitignore @@ -175,3 +175,6 @@ pyrightconfig.json # End of https://www.toptal.com/developers/gitignore/api/python *.lock + +# Raw EEG recordings +recordings/ diff --git a/bridge/eeg/__init__.py b/bridge/eeg/__init__.py index 19b3667..cd914db 100644 --- a/bridge/eeg/__init__.py +++ b/bridge/eeg/__init__.py @@ -1,6 +1,7 @@ from .config import close, init from .connector import EEGConnector from .core import DeviceData, EEGArray, EEGDevice +from .fif import FifDevice, FifRecorder __all__ = [ "DeviceData", @@ -9,4 +10,6 @@ "init", "close", "EEGConnector", + "FifDevice", + "FifRecorder", ] diff --git a/bridge/eeg/brainaccess/__init__.py b/bridge/eeg/brainaccess/__init__.py index a95ed53..3db8f56 100644 --- a/bridge/eeg/brainaccess/__init__.py +++ b/bridge/eeg/brainaccess/__init__.py @@ -1,3 +1,14 @@ -from .device import BrainaccessDevice +from .cap_factory import DEVICE_TO_CAP, get_cap_from_model, get_cap_from_name -__all__ = ["BrainaccessDevice"] +__all__ = [ + "DEVICE_TO_CAP", + "get_cap_from_model", + "get_cap_from_name", +] + +try: + from .device import BrainaccessDevice + + __all__ = [*__all__, "BrainaccessDevice"] +except ImportError: + pass diff --git a/bridge/eeg/brainaccess/device.py b/bridge/eeg/brainaccess/device.py index 0655b17..4148fee 100644 --- a/bridge/eeg/brainaccess/device.py +++ b/bridge/eeg/brainaccess/device.py @@ -50,7 +50,6 @@ def _connect(self, device_name: str, cap: dict[int, str]) -> None: try: self._eeg.setup(self._manager, device_name=device_name, cap=cap) self._electrodes = list(cap.values()) - self._logger.info("Connection successful.") except Exception: self._manager.__exit__(None, None, None) raise @@ -94,8 +93,10 @@ def connect( with connection_lock: self._logger.debug("Scanning for eeg...") - core.scan(adapter_index=bluetooth_adapter) - count = core.get_device_count() + if bluetooth_adapter != 0: + core.config_set_adapter_index(bluetooth_adapter) + devices = core.scan() + count = len(devices) self._logger.info(f"Found {count} eeg.") if count == 0: @@ -106,8 +107,8 @@ def connect( if port >= count: raise ConnectionError(f"Can't connect on port {port}, found {count} eeg.") - self._device_name = core.get_device_name(port) or "Unknown Device" - self._mac_address = core.get_device_address(port) + self._device_name = devices[port].name or "Unknown Device" + self._mac_address = devices[port].mac_address self._cap = get_cap_from_name(self._device_name) if not self._cap: @@ -131,10 +132,12 @@ def disconnect(self) -> None: self._is_streaming = False self._logger.debug("Disconnecting the device...") if self._manager: - self._manager.stop_stream() + try: + self._manager.stop_stream() + except Exception: + pass self._manager.disconnect() self._manager.__exit__(None, None, None) - # self._manager.destroy() self._manager = None self._eeg.close() @@ -175,8 +178,6 @@ def stream(self) -> Generator[EEGArray, None, None]: continue finally: self._is_streaming = False - if self._manager: - self._manager.stop_stream() self._logger.info("Stopped real-time stream.") # IM-032 diff --git a/bridge/eeg/core/device.py b/bridge/eeg/core/device.py index 32eddb1..d843e02 100644 --- a/bridge/eeg/core/device.py +++ b/bridge/eeg/core/device.py @@ -1,4 +1,6 @@ +import threading from abc import ABC, abstractmethod +from collections.abc import Callable from logging import Logger, getLogger from types import TracebackType from typing import Generator @@ -11,6 +13,29 @@ class EEGDevice(ABC): def __init__(self, logger: Logger | None = None) -> None: self._logger = logger or getLogger(__name__) self._logger.debug(f"{self.__class__.__name__} initialized.") + self._subscribers: list[Callable[[EEGArray], None]] = [] + self._push_thread: threading.Thread | None = None + + def subscribe(self, callback: Callable[[EEGArray], None]) -> None: + self._subscribers.append(callback) + + def start(self) -> None: + self._push_thread = threading.Thread(target=self._push_loop, daemon=True) + self._push_thread.start() + + def stop(self) -> None: + self.disconnect() + if self._push_thread is not None: + self._push_thread.join(timeout=5) + self._push_thread = None + + def _push_loop(self) -> None: + for chunk in self.stream(): + for cb in list(self._subscribers): + try: + cb(chunk) + except Exception: + self._logger.exception("Subscriber %r raised", cb) @abstractmethod def connect(self) -> None: diff --git a/bridge/eeg/fif/__init__.py b/bridge/eeg/fif/__init__.py new file mode 100644 index 0000000..d6602c9 --- /dev/null +++ b/bridge/eeg/fif/__init__.py @@ -0,0 +1,4 @@ +from .device import FifDevice +from .recorder import FifRecorder + +__all__ = ["FifDevice", "FifRecorder"] diff --git a/bridge/eeg/fif/device.py b/bridge/eeg/fif/device.py new file mode 100644 index 0000000..cae9ebb --- /dev/null +++ b/bridge/eeg/fif/device.py @@ -0,0 +1,66 @@ +import time +import warnings +from logging import Logger, getLogger +from pathlib import Path +from typing import Any, Final, Generator + +import numpy as np + +from ..core import DeviceData, EEGArray, EEGDevice + + +class FifDevice(EEGDevice): + def __init__(self, file_path: str, chunk_size: int = 25, logger: Logger | None = None) -> None: + try: + import mne # noqa: F401 + except ImportError as e: + raise ImportError("FIF support requires mne: pip install 'neuron-bridge[fif]'") from e + + super().__init__(logger or getLogger(__name__)) + self._path: Final[Path] = Path(file_path) + self._chunk_size: Final[int] = chunk_size + self._data: np.ndarray[Any, Any] | None = None + self._sfreq: float = 250.0 + self._is_connected: bool = False + + def connect(self) -> None: + import mne + + if not self._path.exists(): + raise FileNotFoundError(f"FIF file not found: {self._path}") + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RuntimeWarning) + raw = mne.io.read_raw_fif(str(self._path), preload=True, verbose=False) + self._sfreq = float(raw.info["sfreq"]) + self._data = raw.get_data() + self._is_connected = True + if self._data is not None: + data: np.ndarray[Any, Any] = self._data + self._logger.info("FifDevice connected: %d ch × %d samples @ %.0f Hz", *data.shape, self._sfreq) + + def disconnect(self) -> None: + self._is_connected = False + + def stream(self) -> Generator[EEGArray, None, None]: + if not self._is_connected or self._data is None: + raise RuntimeError("FifDevice not connected.") + + n_samples = self._data.shape[1] + interval: Final[float] = self._chunk_size / self._sfreq + start_perf: Final[float] = time.perf_counter() + + for count, start in enumerate(range(0, n_samples - self._chunk_size + 1, self._chunk_size), start=1): + if not self._is_connected: + break + + target: float = start_perf + count * interval + while time.perf_counter() < target: + diff = target - time.perf_counter() + if diff > 0.002: + time.sleep(diff - 0.001) + + yield self._data[:, start : start + self._chunk_size].astype(np.float64) + + def get_device_data(self) -> DeviceData: + return DeviceData(name=self._path.name, manufacturer="FifSim", sample_rate=int(self._sfreq)) diff --git a/bridge/eeg/fif/recorder.py b/bridge/eeg/fif/recorder.py new file mode 100644 index 0000000..8c42c8a --- /dev/null +++ b/bridge/eeg/fif/recorder.py @@ -0,0 +1,77 @@ +import time +import warnings +from logging import Logger, getLogger +from pathlib import Path +from typing import Any, Final, Generator + +import numpy as np + +from ..core import EEGArray, EEGDevice +from ..core.device_data import RecordingFrame + + +class FifRecorder: + def __init__( + self, + device: EEGDevice, + filename: str, + cap: dict[int, str], + sfreq: float = 250.0, + logger: Logger | None = None, + autosave: bool = True, + connect_device: bool = True, + ) -> None: + try: + import mne # noqa: F401 + except ImportError as e: + raise ImportError("FIF support requires mne: pip install 'neuron-bridge[fif]'") from e + + self._logger: Final[Logger] = logger or getLogger(__name__) + self._device: Final[EEGDevice] = device + self._filename: Final[str] = filename + self._cap: Final[dict[int, str]] = cap + self._sfreq: Final[float] = sfreq + self._autosave: Final[bool] = autosave + self._connect_device: Final[bool] = connect_device + self._frames: list[RecordingFrame] = [] + + def __enter__(self) -> "FifRecorder": + if self._connect_device: + self._device.connect() + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if self._autosave: + self.save() + if self._connect_device: + self._device.disconnect() + + def stream(self) -> Generator[EEGArray, None, None]: + for chunk in self._device.stream(): + self._frames.append(RecordingFrame(timestamp=time.time(), data=chunk)) + yield chunk + + def save(self) -> None: + import mne + + if not self._frames: + self._logger.warning("No data to save.") + return + + try: + output_dir: Final[Path] = Path("recordings") + output_dir.mkdir(exist_ok=True) + file_path: Final[Path] = output_dir / self._filename + + data = np.concatenate([f.data for f in self._frames], axis=1) + ch_names = [self._cap[i] for i in sorted(self._cap)] + info = mne.create_info(ch_names=ch_names, sfreq=self._sfreq, ch_types="eeg") + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RuntimeWarning) + mne.io.RawArray(data, info, verbose=False).save(str(file_path), overwrite=True, verbose=False) + + self._logger.info("Saved session to FIF: %s", file_path) + + except (OSError, IOError) as e: + self._logger.error("Failed to save FIF recording: %s", e) + raise diff --git a/bridge/eeg/file/device.py b/bridge/eeg/file/device.py index 60b5e84..4fc8ed2 100644 --- a/bridge/eeg/file/device.py +++ b/bridge/eeg/file/device.py @@ -1,7 +1,7 @@ import time from logging import Logger, getLogger from pathlib import Path -from typing import Final, Generator +from typing import Any, Final, Generator import numpy as np @@ -11,11 +11,14 @@ class FileDevice(EEGDevice): """Emulator odtwarzający sesje z plików binarnych .npz.""" - def __init__(self, file_path: str, sfreq: float = 250.0, logger: Logger | None = None) -> None: + def __init__( + self, file_path: str, sfreq: float = 250.0, chunk_size: int = 25, logger: Logger | None = None + ) -> None: super().__init__(logger or getLogger(__name__)) self._path: Final[Path] = Path(file_path) self._sfreq: Final[float] = sfreq - self._data: np.ndarray | None = None + self._chunk_size: Final[int] = chunk_size + self._data: np.ndarray[Any, Any] | None = None self._is_connected: bool = False def connect(self) -> None: @@ -26,10 +29,10 @@ def connect(self) -> None: self._data = loader["data"] self._is_connected = True - if not self._data: + if self._data is None or self._data.size == 0: raise ValueError(f"No data found in file: {self._path}") - self._logger.info("FileDevice connected. Loaded %d blocks.", len(self._data)) + self._logger.info("FileDevice connected. Data shape: %s", self._data.shape) def disconnect(self) -> None: self._is_connected = False @@ -38,23 +41,21 @@ def stream(self) -> Generator[EEGArray, None, None]: if not self._is_connected or self._data is None: raise RuntimeError("FileDevice not connected.") - chunk_size: Final[int] = self._data.shape[2] - interval: Final[float] = chunk_size / self._sfreq - + n_samples = self._data.shape[1] + interval: Final[float] = self._chunk_size / self._sfreq start_perf: Final[float] = time.perf_counter() - for count, chunk in enumerate(self._data, start=1): + for count, start in enumerate(range(0, n_samples - self._chunk_size + 1, self._chunk_size), start=1): if not self._is_connected: break target: float = start_perf + (count * interval) - while time.perf_counter() < target: diff = target - time.perf_counter() if diff > 0.002: time.sleep(diff - 0.001) - yield chunk + yield self._data[:, start : start + self._chunk_size].astype(np.float64) def get_device_data(self) -> DeviceData: return DeviceData(name=self._path.name, manufacturer="BinarySim", sample_rate=int(self._sfreq)) diff --git a/bridge/eeg/recorder.py b/bridge/eeg/recorder.py index 7c80e1c..c386485 100644 --- a/bridge/eeg/recorder.py +++ b/bridge/eeg/recorder.py @@ -12,21 +12,31 @@ class EEGRecorder: """Rejestrator EEG wykorzystujący wysokowydajny format binarny NumPy.""" - def __init__(self, device: EEGDevice, filename: str, logger: Logger | None = None, autosave: bool = True) -> None: + def __init__( + self, + device: EEGDevice, + filename: str, + logger: Logger | None = None, + autosave: bool = True, + connect_device: bool = True, + ) -> None: self._logger: Final[Logger] = logger or getLogger(__name__) self._device: Final[EEGDevice] = device self._filename: Final[str] = filename self._autosave: Final[bool] = autosave + self._connect_device: Final[bool] = connect_device self._frames: list[RecordingFrame] = [] def __enter__(self) -> "EEGRecorder": - self._device.connect() + if self._connect_device: + self._device.connect() return self def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: if self._autosave: self.save() - self._device.disconnect() + if self._connect_device: + self._device.disconnect() def stream(self) -> Generator[EEGArray, None, None]: """Strumieniuje dane i buforuje je w pamięci jako RecordingFrame.""" @@ -45,8 +55,8 @@ def save(self) -> None: output_dir.mkdir(exist_ok=True) file_path: Final[Path] = output_dir / self._filename - timestamps: Final[np.ndarray] = np.array([f.timestamp for f in self._frames]) - data_blocks: Final[np.ndarray] = np.array([f.data for f in self._frames]) + timestamps: Final[np.ndarray[Any, Any]] = np.array([f.timestamp for f in self._frames]) + data_blocks: Final[np.ndarray[Any, Any]] = np.concatenate([f.data for f in self._frames], axis=1) np.savez_compressed(file_path, timestamps=timestamps, data=data_blocks) diff --git a/examples/playback_fif.py b/examples/playback_fif.py new file mode 100644 index 0000000..b0955fb --- /dev/null +++ b/examples/playback_fif.py @@ -0,0 +1,30 @@ +from bridge.eeg.fif import FifDevice + + +def playback_session_fif() -> None: + file_path = "recordings/my_brain_data.fif" + + try: + device = FifDevice(file_path=file_path, chunk_size=25) + + print(f"Otwieranie pliku: {file_path}") + + with device: + info = device.get_device_data() + print(f"Symulacja urządzenia: {info.manufacturer} (Źródło: {info.name})") + print(f"Częstotliwość próbkowania: {info.sample_rate} Hz") + + print("Rozpoczynam odtwarzanie strumienia...") + + for i, chunk in enumerate(device.stream()): + avg_signal = chunk.mean() + print(f"Ramka {i:03} | Średnie napięcie: {avg_signal:.4f} uV") + + except FileNotFoundError: + print(f"Błąd: Nie znaleziono pliku {file_path}. Najpierw uruchom record_fif.py") + except Exception as e: + print(f"Wystąpił błąd: {e}") + + +if __name__ == "__main__": + playback_session_fif() diff --git a/examples/record_fif.py b/examples/record_fif.py new file mode 100644 index 0000000..d4b785e --- /dev/null +++ b/examples/record_fif.py @@ -0,0 +1,39 @@ +import time + +from bridge.eeg import EEGConnector +from bridge.eeg.brainaccess import get_cap_from_model +from bridge.eeg.fif import FifRecorder + + +def record_session_fif() -> None: + try: + with EEGConnector() as connector: + device = connector._eeg_device + if not device: + print("Nie znaleziono urządzenia!") + return + + print(f"Połączono z: {device.get_device_data().name}") + + cap = get_cap_from_model("MAXI") + + with FifRecorder( + device, filename="my_brain_data.fif", cap=cap, sfreq=250.0, connect_device=False + ) as recorder: + print("Rozpoczynam zbieranie danych (10 sekund)...") + + start_time = time.time() + for chunk in recorder.stream(): + print(f"Odebrano paczkę o kształcie: {chunk.shape}") + + if time.time() - start_time > 10: + break + + print("Zakończono zbieranie. Zapisywanie do .fif...") + + except Exception as e: + print(f"Wystąpił błąd: {e}") + + +if __name__ == "__main__": + record_session_fif() diff --git a/examples/record_stream.py b/examples/record_stream.py index 7601e1b..79e1044 100644 --- a/examples/record_stream.py +++ b/examples/record_stream.py @@ -1,15 +1,12 @@ import time -from bridge.eeg import EEGConnector, close, init +from bridge.eeg import EEGConnector from bridge.eeg.recorder import EEGRecorder def record_session() -> None: - # 1. Inicjalizacja sterowników SDK - init() - try: - # 2. Używamy Connectora, aby automatycznie znalazł urządzenie + # 1. Używamy Connectora, aby automatycznie znalazł urządzenie with EEGConnector() as connector: device = connector._eeg_device # Pobieramy dostęp do instancji urządzenia if not device: @@ -18,9 +15,9 @@ def record_session() -> None: print(f"Połączono z: {device.get_device_data().name}") - # 3. Tworzymy rekorder (automatycznie zapisze do .npz przy wyjściu z context managera) + # 2. Tworzymy rekorder (automatycznie zapisze do .npz przy wyjściu z context managera) # Plik trafi do folderu recordings/my_brain_data.npz - with EEGRecorder(device, filename="my_brain_data.npz") as recorder: + with EEGRecorder(device, filename="my_brain_data.npz", connect_device=False) as recorder: print("Rozpoczynam zbieranie danych (10 sekund)...") start_time = time.time() @@ -36,9 +33,6 @@ def record_session() -> None: except Exception as e: print(f"Wystąpił błąd: {e}") - finally: - # 4. Zwolnienie zasobów SDK - close() if __name__ == "__main__": diff --git a/pyproject.toml b/pyproject.toml index e5bff50..d488e4a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,10 +33,15 @@ dependencies = [ server = [ "websockets~=12.0", ] +fif = [ + "mne~=1.6.1", + "scipy<1.15.0" +] dev = [ "websockets~=12.0", "pandas~=2.3.2", "mne~=1.6.1", + "scipy<1.15.0", "matplotlib~=3.10.6", "black~=25.1.0", "ruff~=0.13.0", diff --git a/tests/eeg/brainaccess/test_brainaccess.py b/tests/eeg/brainaccess/test_brainaccess.py index 1414244..5c60278 100644 --- a/tests/eeg/brainaccess/test_brainaccess.py +++ b/tests/eeg/brainaccess/test_brainaccess.py @@ -44,21 +44,18 @@ def test_connect_no_devices_found(mock_brainaccess_sdk): def test_connect_successful(mock_brainaccess_sdk): """Test a successful connection flow.""" - mock_brainaccess_sdk.get_device_count.return_value = 1 - mock_brainaccess_sdk.get_device_name.return_value = "BRAINACCESS-MAXI-1234" - mock_brainaccess_sdk.get_device_address.return_value = "00:11:22:33:44:55" - - manager_instance = mock_brainaccess_sdk.EEGManager() + mock_device_info = MagicMock() + mock_device_info.name = "BA MAXI 009" + mock_device_info.mac_address = "00:11:22:33:44:55" + mock_brainaccess_sdk.scan.return_value = [mock_device_info] device = BrainaccessDevice() device.connect(port=0) mock_brainaccess_sdk.scan.assert_called_once() - mock_brainaccess_sdk.get_device_name.assert_called_with(0) - manager_instance.__enter__.assert_called_once() - assert device._device_name == "BRAINACCESS-MAXI-1234" + assert device._device_name == "BA MAXI 009" assert device._mac_address == "00:11:22:33:44:55" - assert "P8" in device._cap.values() # Check if MAXI cap was loaded + assert "P8" in device._cap.values() # MAXI cap loaded by name def test_get_output_calls_sdk_correctly(mock_brainaccess_sdk): diff --git a/tests/eeg/test_fif.py b/tests/eeg/test_fif.py new file mode 100644 index 0000000..3b9a37b --- /dev/null +++ b/tests/eeg/test_fif.py @@ -0,0 +1,129 @@ +import threading +import time +from pathlib import Path + +import numpy as np +import pytest + +from bridge.eeg.core import EEGDevice +from bridge.eeg.core.device_data import DeviceData +from bridge.eeg.fif import FifDevice, FifRecorder + +pytest.importorskip("mne", reason="mne not installed — skipping FIF tests") + +_CAP = {0: "C3", 1: "C4", 2: "Cz", 3: "Fz"} +_N_CH = len(_CAP) +_CHUNK = 25 +_SFREQ = 250.0 + + +class _FakeDevice(EEGDevice): + def __init__(self, n_chunks: int = 8) -> None: + super().__init__() + self._n = n_chunks + self._on = False + + def connect(self) -> None: + self._on = True + + def disconnect(self) -> None: + self._on = False + + def stream(self): + rng = np.random.default_rng(0) + for _ in range(self._n): + yield rng.standard_normal((_N_CH, _CHUNK)) + + def get_device_data(self) -> DeviceData: + return DeviceData(name="FakeDevice", sample_rate=int(_SFREQ)) + + +def _record(tmp_path: Path, n_chunks: int = 8) -> Path: + fif_path = tmp_path / "session.fif" + dev = _FakeDevice(n_chunks=n_chunks) + dev.connect() + rec = FifRecorder(dev, str(fif_path), cap=_CAP, sfreq=_SFREQ) + chunks = [] + + def _run(): + for c in rec.stream(): + chunks.append(c) + + t = threading.Thread(target=_run, daemon=True) + t.start() + time.sleep(0.5) + dev.disconnect() + t.join(timeout=2) + rec.save() + return fif_path + + +def test_fif_recorder_creates_file(tmp_path): + path = _record(tmp_path) + assert path.exists() + assert path.stat().st_size > 0 + + +def test_fif_device_connects_without_error(tmp_path): + path = _record(tmp_path) + device = FifDevice(str(path), chunk_size=_CHUNK) + device.connect() + device.disconnect() + + +def test_fif_device_chunk_shape(tmp_path): + path = _record(tmp_path) + device = FifDevice(str(path), chunk_size=_CHUNK) + device.connect() + chunks = list(device.stream()) + device.disconnect() + assert len(chunks) > 0 + for chunk in chunks: + assert chunk.shape == (_N_CH, _CHUNK) + + +def test_fif_device_chunk_dtype(tmp_path): + path = _record(tmp_path) + device = FifDevice(str(path), chunk_size=_CHUNK) + device.connect() + chunks = list(device.stream()) + device.disconnect() + assert chunks[0].dtype == np.float64 + + +def test_fif_device_missing_file_raises(): + device = FifDevice("nonexistent.fif", chunk_size=_CHUNK) + with pytest.raises(FileNotFoundError): + device.connect() + + +def test_fif_roundtrip_data_matches(tmp_path): + fif_path = tmp_path / "rt.fif" + rng = np.random.default_rng(42) + original = rng.standard_normal((_N_CH, _CHUNK * 4)) + + import warnings + + import mne + + info = mne.create_info(list(_CAP.values()), sfreq=_SFREQ, ch_types="eeg") + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RuntimeWarning) + mne.io.RawArray(original, info, verbose=False).save(str(fif_path), overwrite=True, verbose=False) + + device = FifDevice(str(fif_path), chunk_size=_CHUNK) + device.connect() + chunks = list(device.stream()) + device.disconnect() + + recovered = np.concatenate(chunks, axis=1) + np.testing.assert_allclose(recovered, original[:, : recovered.shape[1]], atol=1e-10) + + +def test_fif_device_get_device_data(tmp_path): + path = _record(tmp_path) + device = FifDevice(str(path), chunk_size=_CHUNK) + device.connect() + info = device.get_device_data() + device.disconnect() + assert info.sample_rate == int(_SFREQ) diff --git a/tests/eeg/test_push_model.py b/tests/eeg/test_push_model.py new file mode 100644 index 0000000..f4a2a5d --- /dev/null +++ b/tests/eeg/test_push_model.py @@ -0,0 +1,104 @@ +import time + +import numpy as np + +from bridge.eeg.core import EEGDevice +from bridge.eeg.core.device_data import DeviceData + + +class _StubDevice(EEGDevice): + def __init__(self, n_chunks: int = 5, chunk_shape: tuple = (4, 10)) -> None: + super().__init__() + self._n_chunks = n_chunks + self._chunk_shape = chunk_shape + self._connected = False + + def connect(self) -> None: + self._connected = True + + def disconnect(self) -> None: + self._connected = False + + def stream(self): + for i in range(self._n_chunks): + if not self._connected: + return + yield np.full(self._chunk_shape, float(i), dtype=np.float64) + + def get_device_data(self) -> DeviceData: + return DeviceData(name="StubDevice", sample_rate=250) + + +def test_subscribe_registers_callback(): + device = _StubDevice() + cb = lambda chunk: None # noqa: E731 + device.subscribe(cb) + assert cb in device._subscribers + + +def test_start_pushes_chunks_to_subscriber(): + device = _StubDevice(n_chunks=4) + received: list[np.ndarray] = [] + + device.connect() + device.subscribe(received.append) + device.start() + device._push_thread.join(timeout=2) + + assert len(received) == 4 + assert received[0].shape == (4, 10) + + +def test_stop_joins_thread(): + device = _StubDevice(n_chunks=10) + device.connect() + device.subscribe(lambda _: time.sleep(0.01)) + device.start() + device.stop() + + assert device._push_thread is None + assert not device._connected + + +def test_multiple_subscribers_all_receive(): + device = _StubDevice(n_chunks=3) + received_a: list = [] + received_b: list = [] + + device.connect() + device.subscribe(received_a.append) + device.subscribe(received_b.append) + device.start() + device._push_thread.join(timeout=2) + + assert len(received_a) == 3 + assert len(received_b) == 3 + + +def test_subscriber_exception_does_not_crash_push_loop(): + device = _StubDevice(n_chunks=3) + received: list = [] + + def bad_cb(chunk): + raise RuntimeError("boom") + + device.connect() + device.subscribe(bad_cb) + device.subscribe(received.append) + device.start() + device._push_thread.join(timeout=2) + + assert len(received) == 3 + + +def test_chunk_values_are_correct(): + device = _StubDevice(n_chunks=3) + received: list[np.ndarray] = [] + + device.connect() + device.subscribe(received.append) + device.start() + device._push_thread.join(timeout=2) + + for i, chunk in enumerate(received): + np.testing.assert_array_equal(chunk, np.full((4, 10), float(i)))