Skip to content

Commit 6f26f5e

Browse files
committed
Update v0.0.4
Add an option for unequal-variance inference. Add an option for a boolean mask that selects the samples used for propensity score estimation.
1 parent 0cbadd4 commit 6f26f5e

6 files changed

Lines changed: 55 additions & 23 deletions

File tree

causarray/DR_estimation.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def _get_func_ps(ps_model, **kwargs):
3333
def cross_fitting(
3434
Y, A, X, X_A, family='poisson', K=1, glm_alpha=1e-4,
3535
ps_model='logistic',
36-
pi_hat=None, Y_hat=None, verbose=False, **kwargs):
36+
Y_hat=None, pi_hat=None, mask=None, verbose=False, **kwargs):
3737
'''
3838
Cross-fitting for causal estimands.
3939
@@ -55,10 +55,16 @@ def cross_fitting(
5555
The regularization parameter for the generalized linear model. The default is 1e-4.
5656
ps_model : str, optional
5757
The propensity score model. The default is 'logistic'.
58-
pi_hat : array, optional
59-
Propensity score of shape (n, a). The default is None.
58+
6059
Y_hat : array, optional
6160
Estimated potential outcome of shape (n, p, a, 2). The default is None.
61+
pi_hat : array, optional
62+
Propensity score of shape (n, a). The default is None.
63+
mask : array, optional
64+
Boolean mask of shape (n, a) for the treatment, indicating which samples are used for
65+
the estimation of the propensity score of each treatment. If None, the samples that
66+
receive no treatment and the samples that receive the given treatment are used.
67+
6268
**kwargs : dict
6369
Additional arguments to pass to the model.
6470
@@ -104,7 +110,12 @@ def cross_fitting(
104110
pi = np.zeros_like(A_test, dtype=float)
105111
for j in range(A.shape[1]):
106112
i_case = (A_train[:,j] == 1.)
107-
i_cells = i_ctrl | i_case
113+
114+
if mask is not None:
115+
i_cells = mask[:, j]
116+
else:
117+
i_ctrl = (np.sum(A_train, axis=1) == 0.)
118+
i_cells = i_ctrl | i_case
108119

109120
if ps_model=='logistic' and XA_train.shape[1]==1 and np.all(XA_train==1):
110121
prob = np.sum(i_case)/np.sum(i_cells)

causarray/DR_learner.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -65,19 +65,16 @@ def compute_causal_estimand(
6565
'''
6666
reset_random_seeds(random_state)
6767

68-
kwargs = {k:v for k,v in kwargs.items() if k not in
69-
['kwargs_ls_1', 'kwargs_ls_2', 'kwargs_es_1', 'kwargs_es_2', 'c1', 'num_d']
70-
}
71-
68+
# check the input data
7269
if isinstance(Y, pd.DataFrame):
7370
gene_names = Y.columns
7471
Y = Y.values
7572
else:
7673
gene_names = range(Y.shape[1])
74+
Y = Y.astype('float')
7775
n, p = Y.shape
7876

79-
if len(A.shape) == 1:
80-
A = A.reshape(-1,1)
77+
if A.ndim == 1: A = A[:, None]
8178
if isinstance(A, pd.DataFrame):
8279
trt_names = A.columns
8380
A = A.values
@@ -97,7 +94,10 @@ def compute_causal_estimand(
9794
if len(mask.shape) == 1: mask = mask.reshape(-1,1)
9895
if mask.shape != A.shape:
9996
raise ValueError('Mask must have the same shape as the treatment matrix')
100-
97+
98+
kwargs = {k:v for k,v in kwargs.items() if k not in
99+
['kwargs_ls_1', 'kwargs_ls_2', 'kwargs_es_1', 'kwargs_es_2', 'c1', 'num_d']
100+
}
101101

102102
if verbose:
103103
d_A = W_A.shape[1]
@@ -113,10 +113,9 @@ def compute_causal_estimand(
113113
else:
114114
offset = None
115115
size_factors = np.ones(n)
116-
117-
Y = Y.astype('float')
116+
118117
Y_hat, pi_hat = cross_fitting(Y, A, W, W_A, family=family, offset=offset,
119-
Y_hat=Y_hat, pi_hat=pi_hat, random_state=random_state, verbose=verbose, **kwargs)
118+
Y_hat=Y_hat, pi_hat=pi_hat, mask=mask, random_state=random_state, verbose=verbose, **kwargs)
120119
pi_hat = pi_hat.reshape(*A.shape)
121120

122121
if verbose: pprint.pprint('Estimating AIPW mean...')
@@ -201,6 +200,9 @@ def LFC(
201200
Boolean mask of shape (n, a) for the treatment, indicating which samples are used for
202201
the estimation of the estimand. This does not affect the estimation of pseudo-outcomes
203202
and propensity scores.
203+
usevar : str
204+
The method to use for estimating the variance of treatment effects.
205+
Options are 'pooled' (default) or 'unequal'.
204206
205207
thres_min : float
206208
The minimum threshold for the treatment effect.

causarray/__about__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.0.3"
1+
__version__ = "0.0.4"

causarray/gcate.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ def fit_gcate(Y, X, A, r, family='nb', disp_glm=None, disp_family=None, offset=T
8282
kwargs : dict
8383
Additional keyword arguments.
8484
'''
85-
85+
if X.ndim == 1: X = X[:, None]
86+
if A.ndim == 1: A = A[:, None]
8687
X = np.hstack((X, A))
8788
a = A.shape[1]
8889
Y, kwargs_glm, lam1 = _check_input(Y, X, family, disp_glm, disp_family, offset, c1, **kwargs)
@@ -195,6 +196,8 @@ def estimate_r(Y, X, A, r_max, c=1.,
195196
df_r : DataFrame
196197
Results of the number of latent factors.
197198
'''
199+
if X.ndim == 1: X = X[:, None]
200+
if A.ndim == 1: A = A[:, None]
198201
a, d = A.shape[1], X.shape[1]
199202
X = np.hstack((X, A))
200203
n, p = Y.shape

causarray/gcate_glm.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,29 @@ def fit_glm(Y, X, A=None, family='gaussian', disp_family='poisson',
4141
A : array
4242
n x 1 vector of treatments or None
4343
family : str
44-
family of GLM to fit, can be one of: 'gaussian', 'poisson', 'nb'
44+
Family of GLM to fit, can be one of: 'gaussian', 'poisson', 'nb'
4545
disp_glm : array or None
46-
dispersion parameter for negative binomial GLM
47-
return_df : bool
48-
whether to return results as DataFrame
49-
impute : bool
50-
whether to impute potential outcomes and get predicted values
46+
Dispersion parameter for negative binomial GLM.
47+
impute : bool or None
48+
Whether to impute missing values in Y.
5149
offset : bool
52-
whether to use log of sum of Y as offset
50+
Whether to use log of sum of Y as offset.
51+
shrinkage : bool
52+
Whether to use regularized GLM.
53+
alpha : float
54+
Regularization parameter for regularized GLM.
55+
maxiter : int
56+
Maximum number of iterations for GLM fitting.
57+
thres_disp : float
58+
Threshold for dispersion parameter for negative binomial GLM.
59+
n_jobs : int
60+
Number of jobs to run in parallel.
61+
random_state : int
62+
Random seed for reproducibility.
63+
verbose : bool
64+
Whether to print progress messages.
65+
kwargs : dict
66+
Additional arguments to pass to GLM fitting.
5367
5468
Returns
5569
-------

causarray/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ def prep_causarray_data(Y, A, X=None, X_A=None, intercept=True):
4949
Y = np.minimum(Y, np.round(np.quantile(np.max(Y, 0), 0.999)))
5050
if not isinstance(A, pd.DataFrame):
5151
A = np.asarray(A)
52+
if A.ndim == 1:
53+
A = A[:, None]
5254

5355
X = np.zeros((Y.shape[0], 0)) if X is None else np.asarray(X)
5456
X_A = X if X_A is None else np.asarray(X_A)

0 commit comments

Comments
 (0)