Skip to content

Commit e2c4215

Browse files
committed
TST add more tests
1 parent af729a6 commit e2c4215

3 files changed

Lines changed: 158 additions & 14 deletions

File tree

.github/workflows/run_tests.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ jobs:
2424
restore-keys: |
2525
${{ runner.os }}-pip-
2626
27+
- uses: actions/cache@v3
28+
with:
29+
path: ~/.voxelwise_tutorials_data/shortclips
30+
key: ${{ runner.os }}-shortclips
31+
2732
- name: Install dependencies
2833
run: |
2934
pip install -U setuptools

voxelwise_tutorials/tests/test_mappers.py

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,33 +6,35 @@
66
import os
77

88
import numpy as np
9+
import cortex
10+
from cortex.testing_utils import has_installed
911
import matplotlib.pyplot as plt
1012

13+
from voxelwise_tutorials.io import load_hdf5_array
1114
from voxelwise_tutorials.io import load_hdf5_sparse_array
1215
from voxelwise_tutorials.viz import plot_flatmap_from_mapper
1316
from voxelwise_tutorials.viz import plot_2d_flatmap_from_mapper
1417

1518
from voxelwise_tutorials.io import get_data_home
1619
from voxelwise_tutorials.io import download_datalad
1720

18-
dataset_directory = get_data_home(dataset="shortclips")
19-
subject_id = "S01"
21+
subject = "S01"
22+
directory = get_data_home(dataset="shortclips")
23+
file_name = os.path.join("mappers", f'{subject}_mappers.hdf')
24+
mapper_file = os.path.join(directory, file_name)
2025

2126
# download mapper if not already present
22-
download_datalad("mappers/S01_mappers.hdf", destination=dataset_directory,
27+
download_datalad(file_name, destination=directory,
2328
source="https://gin.g-node.org/gallantlab/shortclips")
2429

30+
# Change to save = True to save the figures locally and check the results
31+
save_fig = False
2532

26-
def test_flatmap_mappers():
2733

28-
# Change to save = True to save the figures locally and check the results
29-
save_fig = False
34+
def test_flatmap_mappers():
3035

3136
##################
3237
# create fake data
33-
mapper_file = os.path.join(dataset_directory, "mappers",
34-
'{}_mappers.hdf'.format(subject_id))
35-
3638
voxel_to_flatmap = load_hdf5_sparse_array(mapper_file, 'voxel_to_flatmap')
3739
voxels = np.linspace(0, 1, voxel_to_flatmap.shape[1])
3840

@@ -42,7 +44,7 @@ def test_flatmap_mappers():
4244
ax=None)
4345
fig = ax.figure
4446
if save_fig:
45-
fig.savefig(f'{subject_id}.png')
47+
fig.savefig(f'test.png')
4648
plt.close(fig)
4749

4850

@@ -53,9 +55,6 @@ def test_plot_2d_flatmap_from_mapper():
5355

5456
##################
5557
# create fake data
56-
mapper_file = os.path.join(dataset_directory, "mappers",
57-
'{}_mappers.hdf'.format(subject_id))
58-
5958
voxel_to_flatmap = load_hdf5_sparse_array(mapper_file, 'voxel_to_flatmap')
6059
phase = np.linspace(0, 2 * np.pi, voxel_to_flatmap.shape[1])
6160
sin = np.sin(phase)
@@ -67,5 +66,46 @@ def test_plot_2d_flatmap_from_mapper():
6766
vmin=-1, vmax=1, vmin2=-1, vmax2=1)
6867
fig = ax.figure
6968
if save_fig:
70-
fig.savefig(f'{subject_id}.png')
69+
fig.savefig(f'test_2d.png')
70+
plt.close(fig)
71+
72+
73+
def test_roi_masks_shape():
    """Check that every ROI mask stored in the mapper file is a flat
    boolean/weight vector over voxels, and that both mappers agree on the
    number of voxels."""
    mappers = load_hdf5_array(mapper_file, key=None)

    # both mapper shapes are stored as (n_outputs, n_voxels) pairs
    _, n_voxels = mappers['voxel_to_flatmap_shape']
    _, n_voxels_bis = mappers['voxel_to_fsaverage_shape']
    assert n_voxels_bis == n_voxels

    # every ROI mask must span exactly the voxel dimension
    for name, array in mappers.items():
        if 'roi_mask_' in name:
            assert array.shape == (n_voxels, )
83+
84+
85+
def test_fsaverage_mappers():
    """Project fake voxel data to fsaverage and render it with pycortex.

    Uses the module-level ``save_fig`` toggle (do NOT shadow it locally:
    the commit introduces a single module-level flag so that flipping it
    affects all tests in this file).
    """
    ##################
    # create fake data
    voxel_to_fsaverage = load_hdf5_sparse_array(mapper_file,
                                                'voxel_to_fsaverage')
    voxels = np.linspace(0, 1, voxel_to_fsaverage.shape[1])

    ##################
    # download fsaverage subject if not already in the pycortex filestore
    if not hasattr(cortex.db, "fsaverage"):
        cortex.utils.download_subject(subject_id="fsaverage",
                                      pycortex_store=cortex.db.filestore)
        cortex.db.reload_subjects()  # force filestore reload

    #############################
    # plot with fsaverage mappers
    projected = voxel_to_fsaverage @ voxels
    vertex = cortex.Vertex(projected, 'fsaverage', vmin=0, vmax=0.3,
                           cmap='inferno', alpha=0.7, with_curvature=True)
    # ROI overlays require inkscape; skip them when it is not installed
    fig = cortex.quickshow(vertex, with_rois=has_installed("inkscape"))
    if save_fig:
        # plain string: the original used an f-string with no placeholder
        fig.savefig('test_fsaverage.png')
    plt.close(fig)
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import os
2+
3+
import pytest
4+
import numpy as np
5+
6+
from sklearn.model_selection import check_cv
7+
from sklearn.pipeline import make_pipeline
8+
from sklearn.preprocessing import StandardScaler
9+
10+
from himalaya.backend import set_backend
11+
from himalaya.kernel_ridge import KernelRidgeCV
12+
13+
from voxelwise_tutorials.delayer import Delayer
14+
from voxelwise_tutorials.io import load_hdf5_array
15+
from voxelwise_tutorials.io import get_data_home
16+
from voxelwise_tutorials.io import download_datalad
17+
from voxelwise_tutorials.utils import explainable_variance
18+
from voxelwise_tutorials.utils import generate_leave_one_run_out
19+
20+
# use "cupy" or "torch_cuda" for faster computation with GPU
backend = set_backend("numpy", on_error="warn")

# Download the dataset
# NOTE(review): this runs at import time, so collecting the test module
# already triggers the (cached) datalad downloads.
subject = "S01"
feature_spaces = ["motion_energy", "wordnet"]
directory = get_data_home(dataset="shortclips")
for file_name in [
        "features/motion_energy.hdf",
        "features/wordnet.hdf",
        "mappers/S01_mappers.hdf",
        "responses/S01_responses.hdf",
]:
    download_datalad(file_name, destination=directory,
                     source="https://gin.g-node.org/gallantlab/shortclips")
35+
36+
37+
def run_model(X_train, X_test, Y_train, Y_test, run_onsets):
    """Fit a delayed kernel ridge regression and return the test scores.

    The pipeline centers the features, adds temporal delays, then fits a
    linear KernelRidgeCV with a leave-one-run-out cross-validation scheme
    derived from ``run_onsets``.  Returns the test scores as a numpy array.
    """
    ##############
    # define model
    # leave-one-run-out splits, seeded for reproducibility
    splitter = generate_leave_one_run_out(Y_train.shape[0], run_onsets,
                                          random_state=0, n_runs_out=1)
    cv = check_cv(splitter)

    # regularization grid explored by KernelRidgeCV
    grid = np.logspace(-4, 15, 20)

    pipeline = make_pipeline(
        StandardScaler(with_mean=True, with_std=False),
        Delayer(delays=[1, 2, 3, 4]),
        KernelRidgeCV(
            kernel="linear", alphas=grid, cv=cv,
            solver_params=dict(n_targets_batch=1000, n_alphas_batch=10)),
    )

    ###########
    # run model
    pipeline.fit(X_train, Y_train)
    scores = pipeline.score(X_test, Y_test)
    return backend.to_numpy(scores)
64+
65+
66+
@pytest.mark.parametrize('feature_space', feature_spaces)
def test_model_fitting(feature_space):
    """End-to-end check: a ridge model fit on one feature space reaches
    the expected test-score percentiles on well-explained voxels."""
    ###########################################
    # load the data

    # load X (stimulus features)
    features = load_hdf5_array(
        os.path.join(directory, 'features', feature_space + ".hdf"))
    X_train, X_test = features['X_train'], features['X_test']

    # load Y (brain responses)
    responses = load_hdf5_array(
        os.path.join(directory, 'responses', subject + "_responses.hdf"))
    Y_train = responses['Y_train']
    Y_test_repeats = responses['Y_test']
    run_onsets = responses['run_onsets']

    #############################################
    # select voxels based on explainable variance
    mask = explainable_variance(Y_test_repeats) > 0.4
    assert mask.sum() > 0
    Y_train = Y_train[:, mask]
    # average the repeated test presentations of the selected voxels
    Y_test = Y_test_repeats[:, :, mask].mean(0)

    ###########################################
    # fit a ridge model and compute test scores
    test_scores = run_model(X_train, X_test, Y_train, Y_test, run_onsets)
    assert np.percentile(test_scores, 95) > 0.05
    assert np.percentile(test_scores, 99) > 0.15
    assert np.percentile(test_scores, 100) > 0.35

0 commit comments

Comments
 (0)