Skip to content

Commit 456cead

Browse files
committed
Add mixscape
1 parent f1bb151 commit 456cead

2 files changed

Lines changed: 68 additions & 1 deletion

File tree

paper/methods/cinemaot.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import scanpy as sc
88
from scipy.stats import wilcoxon
99
from statsmodels.stats.multitest import multipletests
10+
from sklearn.neighbors import NearestNeighbors
1011

1112

1213
def run_cinemaot(Y, A, raw=False, weighted=False, thres=0.15, smoothness=1e-3, **kwargs):
@@ -69,3 +70,69 @@ def run_cinemaot(Y, A, raw=False, weighted=False, thres=0.15, smoothness=1e-3, *
6970

7071
df = pd.DataFrame({'stat':stat, 'pvalue':pvalue, 'padj':padj})
7172
return {'cinemaot.df':df, 'cinemaot.res':{'W':cf, 'Y_hat_0':Y_hat_0, 'Y_hat_1':Y_hat_1, 'de':de}}
73+
74+
75+
76+
def run_mixscape(Y, A, raw=False, nn=20, **kwargs):
77+
'''
78+
The Python wrapper for running Mixscape.
79+
80+
Parameters
81+
----------
82+
Y : np.ndarray
83+
The (preprocessed) gene expression matrix of shape [n,p].
84+
A : np.ndarray
85+
The treatment vector of shape [n,].
86+
raw : bool
87+
Whether the input data is raw or preprocessed.
88+
nn : int
89+
The number of nearest neighbors to use.
90+
91+
Returns
92+
-------
93+
A dictionary containing the results dataframe and other matrices.
94+
'''
95+
adata = sc.AnnData(Y.copy())
96+
if A.ndim > 1:
97+
A = A[:, 0]
98+
adata.obs['A'] = A
99+
100+
if raw:
101+
sc.pp.normalize_total(adata, target_sum=1e4)
102+
sc.pp.log1p(adata)
103+
sc.pp.scale(adata, max_value=10)
104+
105+
sc.pp.pca(adata)
106+
107+
Y_hat_0 = np.zeros_like(Y)
108+
Y_hat_1 = np.zeros_like(Y)
109+
de = np.zeros_like(Y)
110+
111+
# Calculate counterfactual for treated group (A=1)
112+
X_pca_ctrl = adata.obsm['X_pca'][adata.obs['A'] == 0, :]
113+
X_pca_trt = adata.obsm['X_pca'][adata.obs['A'] == 1, :]
114+
nbrs_ctrl = NearestNeighbors(n_neighbors=nn, algorithm='ball_tree').fit(X_pca_ctrl)
115+
mixscape_matrix_trt = nbrs_ctrl.kneighbors_graph(X_pca_trt).toarray()
116+
117+
Y_hat_0[A == 1] = (mixscape_matrix_trt / np.sum(mixscape_matrix_trt, axis=1, keepdims=True)) @ Y[A == 0]
118+
Y_hat_0[A == 0] = Y[A == 0]
119+
120+
# Calculate counterfactual for control group (A=0)
121+
X_pca_trt = adata.obsm['X_pca'][adata.obs['A'] == 1, :]
122+
X_pca_ctrl = adata.obsm['X_pca'][adata.obs['A'] == 0, :]
123+
nbrs_trt = NearestNeighbors(n_neighbors=nn, algorithm='ball_tree').fit(X_pca_trt)
124+
mixscape_matrix_ctrl = nbrs_trt.kneighbors_graph(X_pca_ctrl).toarray()
125+
126+
Y_hat_1[A == 0] = (mixscape_matrix_ctrl / np.sum(mixscape_matrix_ctrl, axis=1, keepdims=True)) @ Y[A == 1]
127+
Y_hat_1[A == 1] = Y[A == 1]
128+
129+
de[A == 1] = Y[A == 1] - Y_hat_0[A == 1]
130+
de[A == 0] = Y_hat_1[A == 0] - Y[A == 0]
131+
132+
stat, pvalue = list(zip(*[wilcoxon(de[:, j], zero_method='zsplit') for j in range(de.shape[1])]))
133+
padj = multipletests(pvalue, alpha=0.05, method='fdr_bh')[1]
134+
135+
df = pd.DataFrame({'stat': stat, 'pvalue': pvalue, 'padj': padj})
136+
137+
# The concept of a coupling matrix 'W' is specific to OT, returning None.
138+
return {'mixscape.df': df, 'mixscape.res': {'W': adata.obsm['X_pca'], 'Y_hat_0': Y_hat_0, 'Y_hat_1': Y_hat_1, 'de': de}}

paper/methods/metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def comp_stat(true, pred, c):
2525

2626

2727
def comp_score(Y, CF, celltype, Z, W_hat=None):
28-
adata = sc.AnnData(CF)
28+
adata = sc.AnnData(CF.astype('float'))
2929
# adata.obs['covariate'] = W[:,-1]
3030
# adata.obs['trt'] = A
3131

0 commit comments

Comments
 (0)