Skip to content

Commit 5b41f6e

Browse files
committed
Minor updates to GLM functions
1 parent f761486 commit 5b41f6e

4 files changed

Lines changed: 31 additions & 18 deletions

File tree

causarray/DR_estimation.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ def cross_fitting(
6969
pi_hat : array
7070
Estimated propensity score.
7171
'''
72-
func_ps, params_ps = _get_func_ps(ps_model, **kwargs)
73-
params_glm = _filter_params(fit_glm, kwargs)
72+
func_ps, params_ps = _get_func_ps(ps_model, verbose=verbose, **kwargs)
73+
params_glm = _filter_params(fit_glm, {**kwargs, 'verbose': verbose})
7474

7575
if verbose:
7676
pprint.pprint(params_ps)
@@ -97,6 +97,7 @@ def cross_fitting(
9797
Y_train, Y_test = Y[train_index], Y[test_index]
9898

9999
if fit_pi:
100+
if verbose: pprint.pprint('Fit propensity score models...')
100101
i_ctrl = (np.sum(A_train, axis=1) == 0.)
101102

102103
pi = np.zeros_like(A_test, dtype=float)
@@ -111,13 +112,13 @@ def cross_fitting(
111112
else:
112113
pi[:,j] = func_ps(XA_train[i_cells], A_train[i_cells][:,j], XA_test)
113114

115+
if verbose: pprint.pprint('Fit outcome models...')
114116
# Fit GLM on training data and predict on test data
115117
res = fit_glm(Y_train, X_train, A_train, family=family, alpha=glm_alpha,
116118
impute=X_test, **params_glm)
117119

118120
# Store results
119-
if fit_pi:
120-
pi_hat[test_index] = pi
121+
if fit_pi: pi_hat[test_index] = pi
121122

122123
Y_hat[test_index,:,:,0] = res[1][0]
123124
Y_hat[test_index,:,:,1] = res[1][1]

causarray/DR_learner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def compute_causal_estimand(
109109
Y_hat=Y_hat, pi_hat=pi_hat, random_state=random_state, verbose=verbose, **kwargs)
110110
pi_hat = pi_hat.reshape(*A.shape)
111111

112+
if verbose: pprint.pprint('Estimating AIPW mean...')
112113
# point estimation of the treatment effect
113114
_, etas = AIPW_mean(Y, np.stack([1-A, A], axis=-1),
114115
Y_hat, np.stack([1-pi_hat, pi_hat], axis=-1), positive=True)

causarray/__about__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.0.1"
1+
__version__ = "0.0.2"

causarray/gcate_glm.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,26 @@ def fit_glm(Y, X, A=None, family='gaussian', disp_family='poisson',
4949
impute : bool
5050
whether to impute potential outcomes and get predicted values
5151
offset : bool
52-
whether to use log of sum of Y as offset
52+
whether to use log of sum of Y as offset
53+
54+
Returns
55+
-------
56+
B : array
57+
d x p matrix of coefficients
58+
Yhat : array
59+
n x p x a matrix of predicted values
60+
disp_glm : array
61+
p x 1 vector of dispersion parameters
62+
offsets : array
63+
n x 1 vector of offsets
64+
resid_deviance : array
65+
n x p matrix of deviance residuals
5366
'''
5467
np.random.seed(random_state)
5568

5669
if family not in ['gaussian', 'poisson', 'nb']:
5770
raise ValueError('Family not recognized')
58-
71+
5972
d = X.shape[1]
6073

6174
if A is None:
@@ -70,8 +83,6 @@ def fit_glm(Y, X, A=None, family='gaussian', disp_family='poisson',
7083
X_test = X
7184
X_test = np.c_[X,np.zeros_like(A)]
7285
X = np.c_[X,A]
73-
# X_test = X.copy()
74-
7586
a = A.shape[1]
7687

7788
if offset is not None and offset is not False:
@@ -90,7 +101,7 @@ def fit_glm(Y, X, A=None, family='gaussian', disp_family='poisson',
90101
pprint.pprint('Fitting {} GLM{}...'.format(family, '' if offsets is None else ' with offset'))
91102
is_constant = np.all(X == X[0, :], axis=0)
92103
alpha[is_constant] = 0
93-
# alpha[:-a] = 0
104+
94105

95106
families = {
96107
'gaussian': lambda disp: sm.families.Gaussian(),
@@ -122,11 +133,11 @@ def fit_model(j, Y, X, offsets, family, disp, impute, alpha):
122133
Yhat_0 = np.zeros((Y.shape[0], a))
123134
Yhat_1 = np.zeros((Y.shape[0], a))
124135
if impute is not False:
125-
for j in range(a):
136+
for k in range(a):
126137
X_test_copy = X_test.copy()
127-
Yhat_0[:,j] = mod.predict(X_test_copy, offset=offsets)
128-
X_test_copy[:, d+j] = 1 # Update the j-th column with all ones
129-
Yhat_1[:,j] = mod.predict(X_test_copy, offset=offsets)
138+
Yhat_0[:,k] = mod.predict(X_test_copy, offset=offsets)
139+
X_test_copy[:, d+k] = 1
140+
Yhat_1[:,k] = mod.predict(X_test_copy, offset=offsets)
130141
else:
131142
Yhat_0[:,:] = Yhat_1[:,:] = mod.predict(X, offset=offsets).reshape(-1, a)
132143

@@ -146,9 +157,11 @@ def fit_model(j, Y, X, offsets, family, disp, impute, alpha):
146157

147158
results = Parallel(n_jobs=n_jobs)(delayed(fit_model)(
148159
j, Y, X, offsets, family, disp_glm, impute, alpha) for j in tqdm(range(Y.shape[1])))
160+
pprint.pprint('Fitting GLM done.')
161+
if verbose: pprint.pprint('Fitting GLM done.')
149162

150163
B, Yhat_0, Yhat_1, resid_deviance = zip(*results)
151-
B = np.array(B)
164+
B = np.array(B)
152165
Yhat_0 = np.array(Yhat_0).transpose(1, 0, 2)
153166
Yhat_1 = np.array(Yhat_1).transpose(1, 0, 2)
154167
resid_deviance = np.array(resid_deviance).T
@@ -171,7 +184,6 @@ def estimate_disp(Y, X=None, A=None, Y_hat=None, disp_family='gaussian', offset=
171184
else:
172185
offsets = None
173186
sf = 1.
174-
175187

176188
if Y_hat is None:
177189
if verbose:
@@ -183,8 +195,7 @@ def estimate_disp(Y, X=None, A=None, Y_hat=None, disp_family='gaussian', offset=
183195
if disp_family=='gaussian':
184196
Y_norm = Y/sf
185197
reg = LinearRegression(fit_intercept=False).fit(X, Y_norm)
186-
Y_hat = reg.predict(X)
187-
# Y_hat = np.clip(Y_hat, 0, 1) * sf
198+
Y_hat = reg.predict(X)
188199
elif disp_family=='poisson':
189200
Y_hat = fit_glm(Y, X, None, offset=offsets, family='poisson', impute=False, **kwargs)[1]
190201
Y_hat /= sf

0 commit comments

Comments (0)