77import scanpy as sc
88from scipy .stats import wilcoxon
99from statsmodels .stats .multitest import multipletests
10+ from sklearn .neighbors import NearestNeighbors
1011
1112
1213def run_cinemaot (Y , A , raw = False , weighted = False , thres = 0.15 , smoothness = 1e-3 , ** kwargs ):
@@ -69,3 +70,69 @@ def run_cinemaot(Y, A, raw=False, weighted=False, thres=0.15, smoothness=1e-3, *
6970
7071 df = pd .DataFrame ({'stat' :stat , 'pvalue' :pvalue , 'padj' :padj })
7172 return {'cinemaot.df' :df , 'cinemaot.res' :{'W' :cf , 'Y_hat_0' :Y_hat_0 , 'Y_hat_1' :Y_hat_1 , 'de' :de }}
73+
74+
75+
76+ def run_mixscape (Y , A , raw = False , nn = 20 , ** kwargs ):
77+ '''
78+ The Python wrapper for running Mixscape.
79+
80+ Parameters
81+ ----------
82+ Y : np.ndarray
83+ The (preprocessed) gene expression matrix of shape [n,p].
84+ A : np.ndarray
85+ The treatment vector of shape [n,].
86+ raw : bool
87+ Whether the input data is raw or preprocessed.
88+ nn : int
89+ The number of nearest neighbors to use.
90+
91+ Returns
92+ -------
93+ A dictionary containing the results dataframe and other matrices.
94+ '''
95+ adata = sc .AnnData (Y .copy ())
96+ if A .ndim > 1 :
97+ A = A [:, 0 ]
98+ adata .obs ['A' ] = A
99+
100+ if raw :
101+ sc .pp .normalize_total (adata , target_sum = 1e4 )
102+ sc .pp .log1p (adata )
103+ sc .pp .scale (adata , max_value = 10 )
104+
105+ sc .pp .pca (adata )
106+
107+ Y_hat_0 = np .zeros_like (Y )
108+ Y_hat_1 = np .zeros_like (Y )
109+ de = np .zeros_like (Y )
110+
111+ # Calculate counterfactual for treated group (A=1)
112+ X_pca_ctrl = adata .obsm ['X_pca' ][adata .obs ['A' ] == 0 , :]
113+ X_pca_trt = adata .obsm ['X_pca' ][adata .obs ['A' ] == 1 , :]
114+ nbrs_ctrl = NearestNeighbors (n_neighbors = nn , algorithm = 'ball_tree' ).fit (X_pca_ctrl )
115+ mixscape_matrix_trt = nbrs_ctrl .kneighbors_graph (X_pca_trt ).toarray ()
116+
117+ Y_hat_0 [A == 1 ] = (mixscape_matrix_trt / np .sum (mixscape_matrix_trt , axis = 1 , keepdims = True )) @ Y [A == 0 ]
118+ Y_hat_0 [A == 0 ] = Y [A == 0 ]
119+
120+ # Calculate counterfactual for control group (A=0)
121+ X_pca_trt = adata .obsm ['X_pca' ][adata .obs ['A' ] == 1 , :]
122+ X_pca_ctrl = adata .obsm ['X_pca' ][adata .obs ['A' ] == 0 , :]
123+ nbrs_trt = NearestNeighbors (n_neighbors = nn , algorithm = 'ball_tree' ).fit (X_pca_trt )
124+ mixscape_matrix_ctrl = nbrs_trt .kneighbors_graph (X_pca_ctrl ).toarray ()
125+
126+ Y_hat_1 [A == 0 ] = (mixscape_matrix_ctrl / np .sum (mixscape_matrix_ctrl , axis = 1 , keepdims = True )) @ Y [A == 1 ]
127+ Y_hat_1 [A == 1 ] = Y [A == 1 ]
128+
129+ de [A == 1 ] = Y [A == 1 ] - Y_hat_0 [A == 1 ]
130+ de [A == 0 ] = Y_hat_1 [A == 0 ] - Y [A == 0 ]
131+
132+ stat , pvalue = list (zip (* [wilcoxon (de [:, j ], zero_method = 'zsplit' ) for j in range (de .shape [1 ])]))
133+ padj = multipletests (pvalue , alpha = 0.05 , method = 'fdr_bh' )[1 ]
134+
135+ df = pd .DataFrame ({'stat' : stat , 'pvalue' : pvalue , 'padj' : padj })
136+
137+ # The concept of a coupling matrix 'W' is specific to OT, returning None.
138+ return {'mixscape.df' : df , 'mixscape.res' : {'W' : adata .obsm ['X_pca' ], 'Y_hat_0' : Y_hat_0 , 'Y_hat_1' : Y_hat_1 , 'de' : de }}
0 commit comments