Skip to content

Commit 3a27de2

Browse files

Committed: Update estimate_r function

1 parent: ad7c608 · commit 3a27de2

5 files changed

Lines changed: 42 additions & 38 deletions

File tree

causarray/DR_estimation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def cross_fitting(
6969
pi_hat : array
7070
Estimated propensity score.
7171
'''
72-
func_ps, params_ps = _get_func_ps(ps_model, verbose=verbose, **kwargs)
72+
func_ps, params_ps = _get_func_ps(ps_model, verbose=False, **kwargs)
7373
params_glm = _filter_params(fit_glm, {**kwargs, 'verbose': verbose})
7474

7575
if verbose:

causarray/__about__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.0.2"
1+
__version__ = "0.0.3"

causarray/gcate.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,8 @@ def estimate_r(Y, X, A, r_max, c=1.,
195195
df_r : DataFrame
196196
Results of the number of latent factors.
197197
'''
198-
X = np.hstack((X, A))
199198
a, d = A.shape[1], X.shape[1]
199+
X = np.hstack((X, A))
200200
n, p = Y.shape
201201

202202
Y, kwargs_glm, _ = _check_input(Y, X, family, disp_glm, disp_family, offset, None, **kwargs)
@@ -210,28 +210,37 @@ def estimate_r(Y, X, A, r_max, c=1.,
210210
r_list = np.arange(1, int(r_max)+1)
211211
else:
212212
r_list = np.array(r_max, dtype=int)
213-
214-
for r in r_list:
215-
res_1, res_2 = estimate(Y, X, r, a,
216-
0, kwargs_glm, kwargs_ls_1, kwargs_es_1, kwargs_ls_2, kwargs_es_2, **kwargs)
217-
A01, A02, A1, A2 = res_1['X_U'], res_1['B_Gamma'], res_2['X_U'], res_2['B_Gamma']
218-
219-
logh = log_h(Y, family, nuisance)
220-
221-
if r==1:
222-
ll = 2 * (
223-
nll(Y, A01, A02, family, nuisance, size_factor) / p
224-
- np.sum(logh) / (n*p) )
225-
nu = (d+a) * np.maximum(n,p) * np.log(n * p / np.maximum(n,p)) / (n*p)
226-
jic = ll + c * nu
227-
res.append([0, ll, nu, jic])
228-
213+
r_max = np.max(r_list)
214+
215+
# Estimate the residual deviance
216+
res_glm = fit_glm(Y, X, offset=np.log(size_factor[:,0]), family=family, disp_glm=nuisance[0], maxiter=100, verbose=False)
217+
u, s, vt = svds(res_glm[-1], k=r_max)
218+
if u.shape[1]<r_max:
219+
raise ValueError(f'The number of latent factors is larger than the rank of deviance residuals ({u.shape[1]}). Try to decrease the value of r.')
220+
Q, _ = sp.linalg.qr(X, mode='economic')
221+
P1 = np.identity(n) - Q @ Q.T
222+
P1 = P1.astype(type_f)
223+
A1 = np.c_[X, P1 @ u]
224+
225+
logh = log_h(Y, family, nuisance)
226+
ll = 2 * (
227+
nll(Y, X, res_glm[0], family, nuisance, size_factor) / p
228+
- np.sum(logh) / (n*p) )
229+
nu = (d+a) * np.maximum(n,p) * np.log(n * p / np.maximum(n,p)) / (n*p)
230+
jic = ll + c * nu
231+
res.append([0, ll, nu, jic])
232+
233+
for r in r_list[::-1]:
234+
_, res_2 = estimate(Y, X, r, a,
235+
0, kwargs_glm, kwargs_ls_1, kwargs_es_1, kwargs_ls_2, kwargs_es_2, A=A1[:,:d+a+r], **kwargs)
236+
A1, A2 = res_2['X_U'], res_2['B_Gamma']
237+
229238
ll = 2 * (
230239
nll(Y, A1, A2, family, nuisance, size_factor) / p
231240
- np.sum(logh) / (n*p) )
232241
nu = (d + a + r) * np.maximum(n,p) * np.log(n * p / np.maximum(n,p)) / (n*p)
233242
jic = ll + c * nu
234243
res.append([r, ll, nu, jic])
235244

236-
df_r = pd.DataFrame(res, columns=['r', 'deviance', 'nu', 'JIC'])
245+
df_r = pd.DataFrame(res, columns=['r', 'deviance', 'nu', 'JIC']).sort_values(by='r')
237246
return df_r

causarray/gcate_glm.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,7 @@ def fit_model(j, Y, X, offsets, family, disp, impute, alpha):
156156

157157

158158
results = Parallel(n_jobs=n_jobs)(delayed(fit_model)(
159-
j, Y, X, offsets, family, disp_glm, impute, alpha) for j in tqdm(range(Y.shape[1])))
160-
pprint.pprint('Fitting GLM done.')
159+
j, Y, X, offsets, family, disp_glm, impute, alpha) for j in tqdm(range(Y.shape[1]), disable=not verbose))
161160
if verbose: pprint.pprint('Fitting GLM done.')
162161

163162
B, Yhat_0, Yhat_1, resid_deviance = zip(*results)

causarray/gcate_opt.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -208,39 +208,35 @@ def alter_min(
208208
a = d
209209

210210
# initialization for Theta = A @ B^T
211-
if A is None or B is None:
211+
if A is None:
212212
if verbose:
213213
pprint.pprint('Estimating initial latent variables with GLMs...')
214-
res_glm = fit_glm(Y, X, offset=np.log(size_factor[:,0]), family=family, disp_glm=nuisance[0], maxiter=100)
214+
res_glm = fit_glm(Y, X, offset=np.log(size_factor[:,0]), family=family, disp_glm=nuisance[0], maxiter=100, verbose=verbose)
215215
u, s, vt = svds(res_glm[-1], k=r)
216216

217217
if u.shape[1]<r:
218218
raise ValueError(f'The number of latent factors is larger than the rank of deviance residuals ({u.shape[1]}). Try to decrease the value of r.')
219+
220+
A = np.c_[X, P1 @ u]
221+
else:
222+
assert A.shape[1] == d+r
219223

224+
if B is None:
220225
if verbose:
221226
pprint.pprint('Estimating initial coefficients with GLMs...')
222-
A = np.c_[X, P1 @ u]
223-
B = fit_glm(Y, A, offset=np.log(size_factor[:,0]), family=family, disp_glm=nuisance[0], maxiter=100)[0]
227+
228+
B = fit_glm(Y, A, offset=np.log(size_factor[:,0]), family=family, disp_glm=nuisance[0], maxiter=100, verbose=verbose)[0]
224229

225230
E = A[:, -r:] @ B[:, -r:].T
226231
u, s, vh = sp.sparse.linalg.svds(E, k=r)
227232
A[:, d:] = u * s[None,:]**(1/2)
228233
B[:, d:] = vh.T * s[None,:]**(1/2)
229234
del E, u, s, vh
230235

231-
# if offset==1:
232-
# scale = np.sqrt(np.median(np.abs(X[:,0])))
233-
# B[:, :offset] = scale
234-
# A[:, :offset] /= scale
235236

236237
if P2 is not None:
237238
P2 = P2.astype(type_f)
238-
# E = A[:,d-a:] @ B[:,d-a:].T @ (np.identity(p) - P2)
239-
# u, s, vh = sp.sparse.linalg.svds(E, k=r)
240239
B[:, d-a:d] = P2 @ B[:, d-a:d]
241-
# A[:, d:] = u * s[None,:]**(1/2)
242-
# B[:, d:] = vh.T * s[None,:]**(1/2)
243-
# del E, u, s, vh
244240

245241

246242
Y = Y.astype(type_f)
@@ -265,10 +261,10 @@ def alter_min(
265261
kwargs_ls['alpha'] = kwargs_ls['alpha']
266262
if verbose:
267263
pprint.pprint({'kwargs_glm':kwargs_glm,'kwargs_ls':kwargs_ls,'kwargs_es':kwargs_es}, compact=True)
268-
264+
pprint.pprint(f'Fitting GCATE (step {1 if P1 is None else 2})...')
269265
hist = [func_val_pre]
270266
es = Early_Stopping(**kwargs_es)
271-
with tqdm(np.arange(kwargs_es['max_iters'])) as pbar:
267+
with tqdm(np.arange(kwargs_es['max_iters']), disable=not verbose) as pbar:
272268
for t in pbar:
273269
func_val, A, B = update(
274270
Y, A, B, d, weights, P1, P2,
@@ -281,7 +277,7 @@ def alter_min(
281277
pprint.pprint('Encountered large or infinity values. Try to decrease the value of C for the norm constraints.')
282278
break
283279
elif es(func_val):
284-
pbar.set_postfix_str('Early stopped.' + es.info)
280+
pbar.set_postfix_str('Early stopped. ' + es.info)
285281
pbar.close()
286282
break
287283
else:

0 commit comments

Comments (0)