
Commit 18a4599

Update causarray
1 parent 456cead commit 18a4599

3 files changed: 87 additions & 37 deletions

causarray/__about__.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = "0.0.4"
+__version__ = "0.0.5"

paper/methods/causarray/DR_estimation.py

Lines changed: 84 additions & 34 deletions
@@ -1,9 +1,12 @@
 import numpy as np
 from sklearn.linear_model import LogisticRegression
-from sklearn.ensemble import RandomForestClassifier
+from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
+from sklearn_ensemble_cv import reset_random_seeds, Ensemble, ECV
 from causarray.gcate_glm import fit_glm
 from causarray.utils import *
 from causarray.utils import _filter_params
+from joblib import Parallel, delayed
+from tqdm import tqdm
 import pprint
 
 from sklearn.model_selection import KFold, ShuffleSplit
@@ -82,10 +85,15 @@ def cross_fitting(
         pprint.pprint(params_ps)
         pprint.pprint(params_glm)
 
-    if K>1:
-        # Initialize KFold cross-validator
-        kf = KFold(n_splits=K, random_state=0, shuffle=True)
-        folds = kf.split(X)
+    if K > 1:
+        n_samples = X.shape[0]
+        if K >= n_samples:
+            # Use leave-one-out cross-validation
+            folds = [([i for i in range(n_samples) if i != j], [j]) for j in range(n_samples)]
+        else:
+            # Initialize KFold cross-validator
+            kf = KFold(n_splits=int(K), random_state=0, shuffle=True)
+            folds = kf.split(X)
     else:
         folds = [(np.arange(X.shape[0]), np.arange(X.shape[0]))]
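For reference, the new leave-one-out branch builds the same folds that scikit-learn's LeaveOneOut splitter yields. A minimal standalone check (not part of the commit):

import numpy as np
from sklearn.model_selection import LeaveOneOut

n_samples = 5
# Folds as built in the commit: each sample serves as the test set once.
manual = [([i for i in range(n_samples) if i != j], [j]) for j in range(n_samples)]
# Equivalent folds from scikit-learn's LeaveOneOut splitter.
loo = [(tr.tolist(), te.tolist()) for tr, te in LeaveOneOut().split(np.empty((n_samples, 1)))]
assert manual == loo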

@@ -95,6 +103,16 @@ def cross_fitting(
     fit_Y = True if Y_hat is None else False
     Y_hat = np.zeros((Y.shape[0],Y.shape[1],A.shape[1],2), dtype=float) if fit_Y else Y_hat
 
+    # Perform ECV once, before cross-fitting
+    if fit_pi and ps_model == 'random_forest_cv':
+        info_ecv = run_ecv(X_A, A, **params_ps)
+        func_ps, params_ps = _get_func_ps(ps_model, verbose=False, ecv=False,
+            kwargs_ensemble=info_ecv['best_params_ensemble'], kwargs_regr=info_ecv['best_params_regr'])
+        pprint.pprint('Best parameters for the regression model:')
+        pprint.pprint(info_ecv['best_params_regr'])
+        pprint.pprint('Best parameters for the ensemble model:')
+        pprint.pprint(info_ecv['best_params_ensemble'])
+
     # Perform cross-fitting
     for train_index, test_index in folds:
         # Split data
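The block above tunes the propensity-score hyperparameters a single time on the full data and then reuses them inside every fold (with per-fold ECV disabled), instead of re-running the search K times. Schematically, using the function signatures introduced in this commit (the variable names here are illustrative only, not from the repository):

# Tune once on all data, then reuse the selected parameters in each fold.
info = run_ecv(X, A)
params = dict(kwargs_ensemble=info['best_params_ensemble'],
              kwargs_regr=info['best_params_regr'])
for train_index, test_index in folds:
    # ecv=False skips re-tuning; the shared tuned parameters are passed through
    pi_hat = fit_rf(X[train_index], A[train_index],
                    X_test=X[test_index], ecv=False, **params)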
@@ -178,8 +196,6 @@ def AIPW_mean(Y, A, mu, pi, positive=False):
     tau = np.mean(pseudo_y, axis=0)
 
     return tau, pseudo_y
-
-
 
 
 
@@ -188,51 +204,89 @@ def AIPW_mean(Y, A, mu, pi, positive=False):
 
 
 
-from joblib import Parallel, delayed
-from tqdm import tqdm
-from sklearn_ensemble_cv import reset_random_seeds, Ensemble, ECV
-from sklearn.tree import DecisionTreeRegressor
-
-def fit_rf(X, y, X_test=None, sample_weight=None, M=100, M_max=1000,
+def run_ecv(
+    X, y, M=200, M_max=1000,
     # fixed parameters for bagging regressor
-    kwargs_ensemble={'verbose':1},
+    kwargs_ensemble={},
     # fixed parameters for decision tree
-    kwargs_regr={'min_samples_leaf': 3}, # 'min_samples_split': 10, 'max_features':'sqrt'
+    kwargs_regr={},
     # grid search parameters
-    grid_regr = {'max_depth': [11]},
-    grid_ensemble = {'random_state': 0}, #'max_samples':np.linspace(0.25, 1., 4)
-    ):
+    grid_regr={},
+    grid_ensemble={}
+):
+    """
+    Runs Ensemble Cross-Validation (ECV) to find the best hyperparameters.
+    """
+    kwargs_ensemble = {**{'verbose': 1, 'bootstrap': True}, **kwargs_ensemble}
+    kwargs_regr = {**{'min_samples_split': 20, 'min_samples_leaf': 10, 'max_features': 'sqrt', 'ccp_alpha': 0.02, 'class_weight': 'balanced'}, **kwargs_regr}
+    grid_regr = {**{'max_depth': [3, 5, 7]}, **grid_regr}
+    grid_ensemble = {**{'random_state': 0, 'max_samples': [0.4, 0.6, 0.8, 1.]}, **grid_ensemble}
 
     # Validate integer parameters
     M = int(M)
     M_max = int(M_max)
-    # for kwargs in [kwargs_regr, kwargs_ensemble, grid_regr, grid_ensemble]:
-    #     for param in kwargs:
-    #         if param in ['max_depth', 'random_state', 'max_leaf_nodes'] and isinstance(kwargs[param], float):
-    #             kwargs[param] = int(kwargs[param])
 
     # Make sure y is 2D
     y = y.reshape(-1, 1) if y.ndim == 1 else y
 
     # Run ECV
-    res_ecv, info_ecv = ECV(
-        X, y, DecisionTreeRegressor, grid_regr, grid_ensemble,
-        kwargs_regr, kwargs_ensemble,
+    _, info_ecv = ECV(
+        X, y, DecisionTreeClassifier, grid_regr, grid_ensemble,
+        kwargs_regr, kwargs_ensemble,
         M=M, M0=M, M_max=M_max, return_df=True
     )
 
     # Replace the in-sample best parameter for 'n_estimators' with the extrapolated best parameter
     info_ecv['best_params_ensemble']['n_estimators'] = info_ecv['best_n_estimators_extrapolate']
 
+    return info_ecv
+
+
+def fit_rf(
+    X, y, X_test=None, M=100, M_max=1000, ecv=True,
+    # fixed parameters for bagging regressor
+    kwargs_ensemble={},
+    # fixed parameters for decision tree
+    kwargs_regr={},
+    # grid search parameters
+    grid_regr={},
+    grid_ensemble={}
+):
+    """
+    Fits a Random Forest model using parameters found by ECV.
+    """
+
+    kwargs_ensemble = {**{'verbose': 1, 'bootstrap': True}, **kwargs_ensemble}
+    kwargs_regr = {**{'min_samples_split': 20, 'min_samples_leaf': 10, 'max_features': 'sqrt', 'ccp_alpha': 0.02, 'class_weight': 'balanced'}, **kwargs_regr}
+    grid_regr = {**{'max_depth': [3, 5, 7]}, **grid_regr}
+    grid_ensemble = {**{'random_state': 0, 'max_samples': [0.4, 0.6, 0.8, 1.]}, **grid_ensemble}
+
+    # Make sure y is 2D
+    y_2d = y.reshape(-1, 1) if y.ndim == 1 else y
+
+    if ecv:
+        # Get best parameters from ECV
+        info_ecv = run_ecv(
+            X, y_2d, M=M, M_max=M_max,
+            kwargs_ensemble=kwargs_ensemble,
+            kwargs_regr=kwargs_regr,
+            grid_regr=grid_regr,
+            grid_ensemble=grid_ensemble
+        )
+        params_regr = info_ecv['best_params_regr']
+        params_ensemble = info_ecv['best_params_ensemble']
+    else:
+        params_regr = kwargs_regr
+        params_ensemble = kwargs_ensemble
+
     # Fit the ensemble with the best CV parameters
     regr = Ensemble(
-        estimator=DecisionTreeRegressor(**info_ecv['best_params_regr']),
-        **info_ecv['best_params_ensemble']).fit(X, y, sample_weight=sample_weight)
-
+        estimator=DecisionTreeClassifier(**params_regr), **params_ensemble).fit(X, y_2d)
+
     # Predict
     if X_test is None:
         X_test = X
-    return regr.predict(X_test).reshape(-1, y.shape[1])
+    return regr.predict(X_test).reshape(-1, y_2d.shape[1])
 
 
 
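As a quick smoke test of the two new helpers on synthetic data (a sketch, assuming only the sklearn_ensemble_cv API already used above; not from the repository):

import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
y = (X[:, 0] + rng.normal(size=200) > 0).astype(float)  # synthetic binary labels

info = run_ecv(X, y, M=50, M_max=500)   # tuned tree and ensemble parameters
print(info['best_params_regr'], info['best_params_ensemble'])
y_hat = fit_rf(X, y, M=50, M_max=500)   # shape (200, 1) predictions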

@@ -252,11 +306,7 @@ def fit_rf_ind_ps(X, Y, *args, **kwargs):
     def _fit(X, y, i_ctrl, *args, **kwargs):
         i_case = (y == 1.)
         i_cells = i_ctrl | i_case
-        sample_weight = np.ones(y.shape[0])
-        class_weight = len(y) / (2 * np.bincount(y.astype(int)))
-        for a in range(2):
-            sample_weight[y == a] = class_weight[a]
-        return fit_rf(X[i_cells], y[i_cells], sample_weight=sample_weight[i_cells], *args, **kwargs)
+        return fit_rf(X[i_cells], y[i_cells], *args, **kwargs)
 
     Y_hat = Parallel(n_jobs=-1)(delayed(_fit)(X, Y[:,j], i_ctrl, *args, **kwargs)
                                 for j in tqdm(range(Y.shape[1])))
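A note on the deleted manual weighting: scikit-learn's class_weight='balanced' (now set in the decision-tree defaults) computes per-class weights as n_samples / (n_classes * np.bincount(y)), which is exactly the removed formula, so the rebalancing behavior is preserved. A minimal check:

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

y = np.array([0., 0., 0., 1.])
manual = len(y) / (2 * np.bincount(y.astype(int)))  # the formula removed above
balanced = compute_class_weight(class_weight='balanced', classes=np.array([0, 1]), y=y)
assert np.allclose(manual, balanced)  # both give [2/3, 2.0]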

paper/methods/causarray/DR_learner.py

Lines changed: 2 additions & 2 deletions
@@ -170,7 +170,7 @@ def compute_causal_estimand(
 def LFC(
     Y, W, A, W_A=None, family='nb', offset=False,
     Y_hat=None, pi_hat=None, cross_est=False, mask=None, usevar='pooled',
-    thres_min=1e-4, thres_diff=1e-6, eps_var=1e-3,
+    thres_min=1e-2, thres_diff=1e-2, eps_var=1e-4,
     fdx=False, fdx_alpha=0.05, fdx_c=0.1,
     verbose=False, **kwargs):
     '''
@@ -249,7 +249,7 @@ def estimand(etas, A, **kwargs):
         raise ValueError('usevar must be either "pooled" or "unequal"')
 
     # filter out low-expressed genes
-    idx = (np.maximum(tau_0,tau_1)<thres_min) & ((tau_1-tau_0)<thres_diff)
+    idx = (np.maximum(np.abs(tau_0),np.abs(tau_1))<thres_min) | (np.abs(tau_1-tau_0)<thres_diff)
     tau_est[idx] = 0.; eta_est[:,idx] = 0.; var_est[idx] = np.inf
 
     return eta_est, tau_est, var_est
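To see the effect of the revised filter: the old rule compared signed values and required both conditions to hold, so genes with large negative estimands could be flagged while genes with identical nonzero estimands were kept; the new rule flags a gene when both estimands are small in magnitude, or when their difference is negligible. A standalone illustration with made-up values:

import numpy as np

tau_0 = np.array([0.001, -5.0, 0.5])   # made-up control-arm estimands
tau_1 = np.array([-0.001, -5.5, 0.5])  # made-up treated-arm estimands
thres_min, thres_diff = 1e-2, 1e-2     # the new default thresholds

old = (np.maximum(tau_0, tau_1) < thres_min) & ((tau_1 - tau_0) < thres_diff)
new = (np.maximum(np.abs(tau_0), np.abs(tau_1)) < thres_min) | (np.abs(tau_1 - tau_0) < thres_diff)
print(old)  # [ True  True False]: the signed max wrongly flags the large-magnitude gene
print(new)  # [ True False  True]: it is kept, and the no-change gene is filtered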
