@@ -65,19 +65,16 @@ def compute_causal_estimand(
6565 '''
6666 reset_random_seeds (random_state )
6767
68- kwargs = {k :v for k ,v in kwargs .items () if k not in
69- ['kwargs_ls_1' , 'kwargs_ls_2' , 'kwargs_es_1' , 'kwargs_es_2' , 'c1' , 'num_d' ]
70- }
71-
68+ # check the input data
7269 if isinstance (Y , pd .DataFrame ):
7370 gene_names = Y .columns
7471 Y = Y .values
7572 else :
7673 gene_names = range (Y .shape [1 ])
74+ Y = Y .astype ('float' )
7775 n , p = Y .shape
7876
79- if len (A .shape ) == 1 :
80- A = A .reshape (- 1 ,1 )
77+ if A .ndim == 1 : A = A [:, None ]
8178 if isinstance (A , pd .DataFrame ):
8279 trt_names = A .columns
8380 A = A .values
@@ -97,7 +94,10 @@ def compute_causal_estimand(
9794 if len (mask .shape ) == 1 : mask = mask .reshape (- 1 ,1 )
9895 if mask .shape != A .shape :
9996 raise ValueError ('Mask must have the same shape as the treatment matrix' )
100-
97+
98+ kwargs = {k :v for k ,v in kwargs .items () if k not in
99+ ['kwargs_ls_1' , 'kwargs_ls_2' , 'kwargs_es_1' , 'kwargs_es_2' , 'c1' , 'num_d' ]
100+ }
101101
102102 if verbose :
103103 d_A = W_A .shape [1 ]
@@ -113,10 +113,9 @@ def compute_causal_estimand(
113113 else :
114114 offset = None
115115 size_factors = np .ones (n )
116-
117- Y = Y .astype ('float' )
116+
118117 Y_hat , pi_hat = cross_fitting (Y , A , W , W_A , family = family , offset = offset ,
119- Y_hat = Y_hat , pi_hat = pi_hat , random_state = random_state , verbose = verbose , ** kwargs )
118+ Y_hat = Y_hat , pi_hat = pi_hat , mask = mask , random_state = random_state , verbose = verbose , ** kwargs )
120119 pi_hat = pi_hat .reshape (* A .shape )
121120
122121 if verbose : pprint .pprint ('Estimating AIPW mean...' )
@@ -201,6 +200,9 @@ def LFC(
201200 Boolean mask of shape (n, a) for the treatment, indicating which samples are used for
202201 the estimation of the estimand. This does not affect the estimation of pseudo-outcomes
203202 and propensity scores.
203+ usevar : str
204+ The method to use for estimating the variance of treatment effects.
205+ Options are 'pooled' (default) or 'unequal'.
204206
205207 thres_min : float
206208 The minimum threshold for the treatment effect.
0 commit comments