Skip to content

Commit 7798495

Browse files
committed
ENH add explainable_variance function
1 parent c4f5d64 commit 7798495

4 files changed

Lines changed: 56 additions & 21 deletions

File tree

tutorials/movies_3T/01_plot_explainable_variance.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,20 +41,21 @@
4141

4242
###############################################################################
4343
# Then, we compute the explainable variance per voxel.
44-
# The variance of the signal is estimated by taking the variance over
45-
# time (``var``). The variance of the component shared across repeats
46-
# is estimated by taking the variance of the average response (``var_mean``).
47-
# Then, we can compute the explainable variance by dividing the two quantities.
44+
# The variance of the signal is estimated by taking the average variance over
45+
# repeats. The variance of the component shared across repeats is estimated by
46+
# taking the variance of the average response. Then, we compute the
47+
# explainable variance by dividing these two quantities.
48+
# Finally, a correction can be applied to account for small numbers of repeats.
4849

49-
var = np.var(Y_test.reshape(-1, Y_test.shape[-1]), axis=0)
50-
var_mean = np.var(np.mean(Y_test, axis=0), axis=0)
51-
explainable_variance = var_mean / var
50+
from voxelwise.utils import explainable_variance
51+
52+
ev = explainable_variance(Y_test, bias_correction=False)
5253

5354
###############################################################################
5455
# Plot the distribution of explainable variance over voxels.
5556
import matplotlib.pyplot as plt
5657

57-
plt.hist(explainable_variance, bins=np.linspace(0, 1, 100), log=True,
58+
plt.hist(ev, bins=np.linspace(0, 1, 100), log=True,
5859
histtype='step')
5960
plt.xlabel("Explainable variance")
6061
plt.ylabel("Number of voxels")
@@ -68,7 +69,7 @@
6869
from voxelwise.viz import plot_flatmap_from_mapper
6970

7071
mapper_file = os.path.join(directory, 'mappers', f'{subject}_mappers.hdf')
71-
plot_flatmap_from_mapper(explainable_variance, mapper_file, vmin=0, vmax=0.7)
72+
plot_flatmap_from_mapper(ev, mapper_file, vmin=0, vmax=0.7)
7273
plt.show()
7374

7475
###############################################################################

tutorials/movies_3T/README.rst

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,14 @@ This tutorial implements different voxelwise models:
2323
**Requirements:**
2424
This tutorial requires the following Python packages:
2525

26-
- numpy (for the data array)
27-
- scipy (for motion energy extraction)
28-
- h5py (for loading the data files)
29-
- scikit-learn (for preprocessing and modeling)
30-
- himalaya (for modeling)
31-
- pymoten (for extracting motion energy)
26+
- numpy
27+
- scipy
28+
- h5py
29+
- scikit-learn
3230
- voxelwise (this repository)
31+
- himalaya
32+
- pymoten (optional, for extracting motion energy)
33+
- cupy/pytorch (optional, to use GPU in himalaya)
3334

3435

3536
**References:**

tutorials/movies_4T/README.rst

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,14 @@ publication [5]_, and the CRCNS data set [6]_.
2626
**Requirements:**
2727
This tutorial requires the following Python packages:
2828

29-
- numpy (for the data array)
30-
- scipy (for motion energy extraction)
31-
- h5py (for loading the data files)
32-
- scikit-learn (for preprocessing and modeling)
33-
- himalaya (for modeling)
34-
- pymoten (for extracting motion energy)
29+
- numpy
30+
- scipy
31+
- h5py
32+
- scikit-learn
3533
- voxelwise (this repository)
34+
- himalaya
35+
- pymoten (optional, for extracting motion energy)
36+
- cupy/pytorch (optional, to use GPU in himalaya)
3637

3738
**References:**
3839

voxelwise/utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import scipy.stats
23
from sklearn.utils.validation import check_random_state
34

45

@@ -43,3 +44,34 @@ def generate_leave_one_run_out(n_samples, run_onsets, random_state=None,
4344
[runs[jj] for jj in range(n_runs) if jj not in val_runs])
4445
val = np.hstack([runs[jj] for jj in range(n_runs) if jj in val_runs])
4546
yield train, val
47+
48+
49+
def explainable_variance(data, bias_correction=True, do_zscore=True):
    """Compute explainable variance for a set of voxels.

    The explainable variance is the fraction of the response variance that
    is shared across repetitions of the same stimulus. The total variance is
    estimated by the average (over repeats) of the variance over time, and
    the shared variance by the variance over time of the average response.
    The ratio of the two gives the explainable variance per voxel.

    Parameters
    ----------
    data : array of shape (n_repeats, n_times, n_voxels)
        fMRI responses of the repeated test set.
    bias_correction : bool
        Perform bias correction based on the number of repetitions, to
        account for the upward bias of the ratio with few repeats.
    do_zscore : bool
        z-score the data in time. Only set to False if your data time
        courses are already z-scored.

    Returns
    -------
    ev : array of shape (n_voxels, )
        Explainable variance per voxel.
    """
    if do_zscore:
        # z-score each repeat independently over time, per voxel.
        data = scipy.stats.zscore(data, axis=1)

    # Average over repeats of the per-repeat variance over time.
    mean_var = data.var(axis=1, dtype=np.float64, ddof=1).mean(axis=0)
    # Variance over time of the response averaged over repeats.
    var_mean = data.mean(axis=0).var(axis=0, dtype=np.float64, ddof=1)
    ev = var_mean / mean_var

    if bias_correction:
        # Correct the upward bias of the estimate when the number of
        # repeats is small; with identical repeats ev stays equal to 1.
        n_repeats = data.shape[0]
        ev = ev - (1 - ev) / (n_repeats - 1)
    return ev

0 commit comments

Comments
 (0)