Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@ name: CI

on:
push:
branches: [ "dev", "master" ]
branches: ["dev", "master"]
pull_request:
branches: [ "dev", "master" ]
workflow_dispatch:

jobs:
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,6 @@ pyrightconfig.json

# End of https://www.toptal.com/developers/gitignore/api/python
*.lock

# Raw EEG recordings
recordings/
3 changes: 3 additions & 0 deletions bridge/eeg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .config import close, init
from .connector import EEGConnector
from .core import DeviceData, EEGArray, EEGDevice
from .fif import FifDevice, FifRecorder

__all__ = [
"DeviceData",
Expand All @@ -9,4 +10,6 @@
"init",
"close",
"EEGConnector",
"FifDevice",
"FifRecorder",
]
15 changes: 13 additions & 2 deletions bridge/eeg/brainaccess/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
from .device import BrainaccessDevice
from .cap_factory import DEVICE_TO_CAP, get_cap_from_model, get_cap_from_name

__all__ = ["BrainaccessDevice"]
__all__ = [
"DEVICE_TO_CAP",
"get_cap_from_model",
"get_cap_from_name",
]

try:
from .device import BrainaccessDevice

__all__ = [*__all__, "BrainaccessDevice"]
except ImportError:
pass
19 changes: 10 additions & 9 deletions bridge/eeg/brainaccess/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def _connect(self, device_name: str, cap: dict[int, str]) -> None:
try:
self._eeg.setup(self._manager, device_name=device_name, cap=cap)
self._electrodes = list(cap.values())
self._logger.info("Connection successful.")
except Exception:
self._manager.__exit__(None, None, None)
raise
Expand Down Expand Up @@ -94,8 +93,10 @@ def connect(

with connection_lock:
self._logger.debug("Scanning for eeg...")
core.scan(adapter_index=bluetooth_adapter)
count = core.get_device_count()
if bluetooth_adapter != 0:
core.config_set_adapter_index(bluetooth_adapter)
devices = core.scan()
count = len(devices)
self._logger.info(f"Found {count} eeg.")

if count == 0:
Expand All @@ -106,8 +107,8 @@ def connect(
if port >= count:
raise ConnectionError(f"Can't connect on port {port}, found {count} eeg.")

self._device_name = core.get_device_name(port) or "Unknown Device"
self._mac_address = core.get_device_address(port)
self._device_name = devices[port].name or "Unknown Device"
self._mac_address = devices[port].mac_address
self._cap = get_cap_from_name(self._device_name)

if not self._cap:
Expand All @@ -131,10 +132,12 @@ def disconnect(self) -> None:
self._is_streaming = False
self._logger.debug("Disconnecting the device...")
if self._manager:
self._manager.stop_stream()
try:
self._manager.stop_stream()
except Exception:
pass
self._manager.disconnect()
self._manager.__exit__(None, None, None)
# self._manager.destroy()
self._manager = None
self._eeg.close()

Expand Down Expand Up @@ -175,8 +178,6 @@ def stream(self) -> Generator[EEGArray, None, None]:
continue
finally:
self._is_streaming = False
if self._manager:
self._manager.stop_stream()
self._logger.info("Stopped real-time stream.")

# IM-032
Expand Down
25 changes: 25 additions & 0 deletions bridge/eeg/core/device.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import threading
from abc import ABC, abstractmethod
from collections.abc import Callable
from logging import Logger, getLogger
from types import TracebackType
from typing import Generator
Expand All @@ -11,6 +13,29 @@ class EEGDevice(ABC):
def __init__(self, logger: Logger | None = None) -> None:
self._logger = logger or getLogger(__name__)
self._logger.debug(f"{self.__class__.__name__} initialized.")
self._subscribers: list[Callable[[EEGArray], None]] = []
self._push_thread: threading.Thread | None = None

def subscribe(self, callback: Callable[[EEGArray], None]) -> None:
self._subscribers.append(callback)

def start(self) -> None:
self._push_thread = threading.Thread(target=self._push_loop, daemon=True)
self._push_thread.start()

def stop(self) -> None:
self.disconnect()
if self._push_thread is not None:
self._push_thread.join(timeout=5)
self._push_thread = None

def _push_loop(self) -> None:
for chunk in self.stream():
for cb in list(self._subscribers):
try:
cb(chunk)
except Exception:
self._logger.exception("Subscriber %r raised", cb)

@abstractmethod
def connect(self) -> None:
Expand Down
4 changes: 4 additions & 0 deletions bridge/eeg/fif/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .device import FifDevice
from .recorder import FifRecorder

__all__ = ["FifDevice", "FifRecorder"]
66 changes: 66 additions & 0 deletions bridge/eeg/fif/device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import time
import warnings
from logging import Logger, getLogger
from pathlib import Path
from typing import Any, Final, Generator

import numpy as np

from ..core import DeviceData, EEGArray, EEGDevice


class FifDevice(EEGDevice):
def __init__(self, file_path: str, chunk_size: int = 25, logger: Logger | None = None) -> None:
try:
import mne # noqa: F401
except ImportError as e:
raise ImportError("FIF support requires mne: pip install 'neuron-bridge[fif]'") from e

super().__init__(logger or getLogger(__name__))
self._path: Final[Path] = Path(file_path)
self._chunk_size: Final[int] = chunk_size
self._data: np.ndarray[Any, Any] | None = None
self._sfreq: float = 250.0
self._is_connected: bool = False

def connect(self) -> None:
import mne

if not self._path.exists():
raise FileNotFoundError(f"FIF file not found: {self._path}")

with warnings.catch_warnings():
warnings.simplefilter("ignore", RuntimeWarning)
raw = mne.io.read_raw_fif(str(self._path), preload=True, verbose=False)
self._sfreq = float(raw.info["sfreq"])
self._data = raw.get_data()
self._is_connected = True
if self._data is not None:
data: np.ndarray[Any, Any] = self._data
self._logger.info("FifDevice connected: %d ch × %d samples @ %.0f Hz", *data.shape, self._sfreq)

def disconnect(self) -> None:
self._is_connected = False

def stream(self) -> Generator[EEGArray, None, None]:
if not self._is_connected or self._data is None:
raise RuntimeError("FifDevice not connected.")

n_samples = self._data.shape[1]
interval: Final[float] = self._chunk_size / self._sfreq
start_perf: Final[float] = time.perf_counter()

for count, start in enumerate(range(0, n_samples - self._chunk_size + 1, self._chunk_size), start=1):
if not self._is_connected:
break

target: float = start_perf + count * interval
while time.perf_counter() < target:
diff = target - time.perf_counter()
if diff > 0.002:
time.sleep(diff - 0.001)

yield self._data[:, start : start + self._chunk_size].astype(np.float64)

def get_device_data(self) -> DeviceData:
return DeviceData(name=self._path.name, manufacturer="FifSim", sample_rate=int(self._sfreq))
77 changes: 77 additions & 0 deletions bridge/eeg/fif/recorder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import time
import warnings
from logging import Logger, getLogger
from pathlib import Path
from typing import Any, Final, Generator

import numpy as np

from ..core import EEGArray, EEGDevice
from ..core.device_data import RecordingFrame


class FifRecorder:
def __init__(
self,
device: EEGDevice,
filename: str,
cap: dict[int, str],
sfreq: float = 250.0,
logger: Logger | None = None,
autosave: bool = True,
connect_device: bool = True,
) -> None:
try:
import mne # noqa: F401
except ImportError as e:
raise ImportError("FIF support requires mne: pip install 'neuron-bridge[fif]'") from e

self._logger: Final[Logger] = logger or getLogger(__name__)
self._device: Final[EEGDevice] = device
self._filename: Final[str] = filename
self._cap: Final[dict[int, str]] = cap
self._sfreq: Final[float] = sfreq
self._autosave: Final[bool] = autosave
self._connect_device: Final[bool] = connect_device
self._frames: list[RecordingFrame] = []

def __enter__(self) -> "FifRecorder":
if self._connect_device:
self._device.connect()
return self

def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if self._autosave:
self.save()
if self._connect_device:
self._device.disconnect()

def stream(self) -> Generator[EEGArray, None, None]:
for chunk in self._device.stream():
self._frames.append(RecordingFrame(timestamp=time.time(), data=chunk))
yield chunk

def save(self) -> None:
import mne

if not self._frames:
self._logger.warning("No data to save.")
return

try:
output_dir: Final[Path] = Path("recordings")
output_dir.mkdir(exist_ok=True)
file_path: Final[Path] = output_dir / self._filename

data = np.concatenate([f.data for f in self._frames], axis=1)
ch_names = [self._cap[i] for i in sorted(self._cap)]
info = mne.create_info(ch_names=ch_names, sfreq=self._sfreq, ch_types="eeg")
with warnings.catch_warnings():
warnings.simplefilter("ignore", RuntimeWarning)
mne.io.RawArray(data, info, verbose=False).save(str(file_path), overwrite=True, verbose=False)

self._logger.info("Saved session to FIF: %s", file_path)

except (OSError, IOError) as e:
self._logger.error("Failed to save FIF recording: %s", e)
raise
23 changes: 12 additions & 11 deletions bridge/eeg/file/device.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import time
from logging import Logger, getLogger
from pathlib import Path
from typing import Final, Generator
from typing import Any, Final, Generator

import numpy as np

Expand All @@ -11,11 +11,14 @@
class FileDevice(EEGDevice):
"""Emulator odtwarzający sesje z plików binarnych .npz."""

def __init__(self, file_path: str, sfreq: float = 250.0, logger: Logger | None = None) -> None:
def __init__(
self, file_path: str, sfreq: float = 250.0, chunk_size: int = 25, logger: Logger | None = None
) -> None:
super().__init__(logger or getLogger(__name__))
self._path: Final[Path] = Path(file_path)
self._sfreq: Final[float] = sfreq
self._data: np.ndarray | None = None
self._chunk_size: Final[int] = chunk_size
self._data: np.ndarray[Any, Any] | None = None
self._is_connected: bool = False

def connect(self) -> None:
Expand All @@ -26,10 +29,10 @@ def connect(self) -> None:
self._data = loader["data"]

self._is_connected = True
if not self._data:
if self._data is None or self._data.size == 0:
raise ValueError(f"No data found in file: {self._path}")

self._logger.info("FileDevice connected. Loaded %d blocks.", len(self._data))
self._logger.info("FileDevice connected. Data shape: %s", self._data.shape)

def disconnect(self) -> None:
self._is_connected = False
Expand All @@ -38,23 +41,21 @@ def stream(self) -> Generator[EEGArray, None, None]:
if not self._is_connected or self._data is None:
raise RuntimeError("FileDevice not connected.")

chunk_size: Final[int] = self._data.shape[2]
interval: Final[float] = chunk_size / self._sfreq

n_samples = self._data.shape[1]
interval: Final[float] = self._chunk_size / self._sfreq
start_perf: Final[float] = time.perf_counter()

for count, chunk in enumerate(self._data, start=1):
for count, start in enumerate(range(0, n_samples - self._chunk_size + 1, self._chunk_size), start=1):
if not self._is_connected:
break

target: float = start_perf + (count * interval)

while time.perf_counter() < target:
diff = target - time.perf_counter()
if diff > 0.002:
time.sleep(diff - 0.001)

yield chunk
yield self._data[:, start : start + self._chunk_size].astype(np.float64)

def get_device_data(self) -> DeviceData:
return DeviceData(name=self._path.name, manufacturer="BinarySim", sample_rate=int(self._sfreq))
Loading
Loading