@@ -76,14 +76,14 @@ class ConditionalImputer(Imputer):
7676 def __init__ (self , missing_values = "NaN" , strategy = "mean" ,
7777 strategy_nominal = "most_frequent" ,
7878 categorical_features = None ,
79- empty_attribute_constant = None ,
79+ fill_empty = None ,
8080 axis = 0 , verbose = 0 , copy = True ):
8181 self .missing_values = missing_values
8282 self .strategy = strategy
8383 self .strategy_nominal = strategy_nominal
8484 self .categorical_features = categorical_features
8585 self .categorical_features_implied = None
86- self .empty_attribute_constant = empty_attribute_constant
86+ self .fill_empty = fill_empty
8787 self .axis = axis
8888 self .verbose = verbose
8989 self .copy = copy
@@ -157,24 +157,48 @@ def fit(self, X, y=None):
157157
158158 def transform (self , X ):
159159 """Impute all missing values in X.
160+
160161 Parameters
161162 ----------
162163 X : {array-like, sparse matrix}, shape = [n_samples, n_features]
163164 The input data to complete.
164165 """
165- check_is_fitted (self , 'statistics_' )
166- X = check_array (X , accept_sparse = 'csc' , dtype = FLOAT_DTYPES ,
167- force_all_finite = False , copy = self .copy )
168- statistics = self .statistics_
169- if X .shape [1 ] != statistics .shape [0 ]:
170- raise ValueError ("X has %d features per sample, expected %d"
171- % (X .shape [1 ], self .statistics_ .shape [0 ]))
172-
173- # impute completelly empty columns with constant
174- if self .empty_attribute_constant is not None :
175- invalid_mask = np .isnan (statistics )
176- X [:, invalid_mask ] = self .empty_attribute_constant
177- statistics [invalid_mask ] = self .empty_attribute_constant
166+ if self .axis == 0 :
167+ check_is_fitted (self , 'statistics_' )
168+ X = check_array (X , accept_sparse = 'csc' , dtype = FLOAT_DTYPES ,
169+ force_all_finite = False , copy = self .copy )
170+ statistics = self .statistics_ .copy ()
171+ if X .shape [1 ] != statistics .shape [0 ]:
172+ raise ValueError ("X has %d features per sample, expected %d"
173+ % (X .shape [1 ], self .statistics_ .shape [0 ]))
174+
175+ # Since two different arrays can be provided in fit(X) and
176+ # transform(X), the imputation data need to be recomputed
177+ # when the imputation is done per sample
178+ else :
179+ X = check_array (X , accept_sparse = 'csr' , dtype = FLOAT_DTYPES ,
180+ force_all_finite = False , copy = self .copy )
181+
182+ if sparse .issparse (X ):
183+ statistics = self ._sparse_fit (X ,
184+ self .strategy ,
185+ self .missing_values ,
186+ self .axis )
187+
188+ else :
189+ statistics = self ._dense_fit (X ,
190+ self .strategy ,
191+ self .missing_values ,
192+ self .axis )
193+
194+ # impute completely empty columns with constant, if
195+ # `fill_empty` parameter was set
196+ if self .fill_empty is not None :
197+ if sparse .issparse (X ):
198+ X = X .toarray ()
199+ empty_mask = np .all (_get_mask (X , self .missing_values ),
200+ axis = self .axis )
201+ statistics [empty_mask ] = self .fill_empty
178202
179203 # Delete the invalid rows/columns
180204 invalid_mask = np .isnan (statistics )
@@ -183,11 +207,14 @@ def transform(self, X):
183207 valid_statistics_indexes = np .where (valid_mask )[0 ]
184208 missing = np .arange (X .shape [not self .axis ])[invalid_mask ]
185209
186- if invalid_mask .any ():
210+ if self . axis == 0 and invalid_mask .any ():
187211 if self .verbose :
188212 warnings .warn ("Deleting features without "
189213 "observed values: %s" % missing )
190214 X = X [:, valid_statistics_indexes ]
215+ elif self .axis == 1 and invalid_mask .any ():
216+ raise ValueError ("Some rows only contain "
217+ "missing values: %s" % missing )
191218
192219 # Do actual imputation
193220 if sparse .issparse (X ) and self .missing_values != 0 :
@@ -205,7 +232,10 @@ def transform(self, X):
205232 n_missing = np .sum (mask , axis = self .axis )
206233 values = np .repeat (valid_statistics , n_missing )
207234
208- coordinates = np .where (mask .transpose ())[::- 1 ]
235+ if self .axis == 0 :
236+ coordinates = np .where (mask .transpose ())[::- 1 ]
237+ else :
238+ coordinates = mask
209239
210240 X [coordinates ] = values
211241
0 commit comments