@@ -195,8 +195,8 @@ def estimate_r(Y, X, A, r_max, c=1.,
195195 df_r : DataFrame
196196 Results of the number of latent factors.
197197 '''
198- X = np .hstack ((X , A ))
199198 a , d = A .shape [1 ], X .shape [1 ]
199+ X = np .hstack ((X , A ))
200200 n , p = Y .shape
201201
202202 Y , kwargs_glm , _ = _check_input (Y , X , family , disp_glm , disp_family , offset , None , ** kwargs )
@@ -210,28 +210,37 @@ def estimate_r(Y, X, A, r_max, c=1.,
210210 r_list = np .arange (1 , int (r_max )+ 1 )
211211 else :
212212 r_list = np .array (r_max , dtype = int )
213-
214- for r in r_list :
215- res_1 , res_2 = estimate (Y , X , r , a ,
216- 0 , kwargs_glm , kwargs_ls_1 , kwargs_es_1 , kwargs_ls_2 , kwargs_es_2 , ** kwargs )
217- A01 , A02 , A1 , A2 = res_1 ['X_U' ], res_1 ['B_Gamma' ], res_2 ['X_U' ], res_2 ['B_Gamma' ]
218-
219- logh = log_h (Y , family , nuisance )
220-
221- if r == 1 :
222- ll = 2 * (
223- nll (Y , A01 , A02 , family , nuisance , size_factor ) / p
224- - np .sum (logh ) / (n * p ) )
225- nu = (d + a ) * np .maximum (n ,p ) * np .log (n * p / np .maximum (n ,p )) / (n * p )
226- jic = ll + c * nu
227- res .append ([0 , ll , nu , jic ])
228-
213+ r_max = np .max (r_list )
214+
215+ # Estimate the residual deviance
216+ res_glm = fit_glm (Y , X , offset = np .log (size_factor [:,0 ]), family = family , disp_glm = nuisance [0 ], maxiter = 100 , verbose = False )
217+ u , s , vt = svds (res_glm [- 1 ], k = r_max )
218+ if u .shape [1 ]< r_max :
219+ raise ValueError (f'The number of latent factors is larger than the rank of deviance residuals ({ u .shape [1 ]} ). Try to decrease the value of r.' )
220+ Q , _ = sp .linalg .qr (X , mode = 'economic' )
221+ P1 = np .identity (n ) - Q @ Q .T
222+ P1 = P1 .astype (type_f )
223+ A1 = np .c_ [X , P1 @ u ]
224+
225+ logh = log_h (Y , family , nuisance )
226+ ll = 2 * (
227+ nll (Y , X , res_glm [0 ], family , nuisance , size_factor ) / p
228+ - np .sum (logh ) / (n * p ) )
229+ nu = (d + a ) * np .maximum (n ,p ) * np .log (n * p / np .maximum (n ,p )) / (n * p )
230+ jic = ll + c * nu
231+ res .append ([0 , ll , nu , jic ])
232+
233+ for r in r_list [::- 1 ]:
234+ _ , res_2 = estimate (Y , X , r , a ,
235+ 0 , kwargs_glm , kwargs_ls_1 , kwargs_es_1 , kwargs_ls_2 , kwargs_es_2 , A = A1 [:,:d + a + r ], ** kwargs )
236+ A1 , A2 = res_2 ['X_U' ], res_2 ['B_Gamma' ]
237+
229238 ll = 2 * (
230239 nll (Y , A1 , A2 , family , nuisance , size_factor ) / p
231240 - np .sum (logh ) / (n * p ) )
232241 nu = (d + a + r ) * np .maximum (n ,p ) * np .log (n * p / np .maximum (n ,p )) / (n * p )
233242 jic = ll + c * nu
234243 res .append ([r , ll , nu , jic ])
235244
236- df_r = pd .DataFrame (res , columns = ['r' , 'deviance' , 'nu' , 'JIC' ])
245+ df_r = pd .DataFrame (res , columns = ['r' , 'deviance' , 'nu' , 'JIC' ]). sort_values ( by = 'r' )
237246 return df_r
0 commit comments