File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 9999 if 'cudf' in str (type (y_train )):
100100 params .n_classes = y_train [y_train .columns [0 ]].nunique ()
101101 else :
102- params .n_classes = len (np .unique (y_train ))
102+ unique_y_train = np .unique (y_train )
103+ params .n_classes = len (unique_y_train )
104+ if max (unique_y_train ) != len (unique_y_train ) - 1 :
105+ params .n_classes = int (max (unique_y_train )) + 1
106+
103107 if params .n_classes > 2 :
104108 lgbm_params ['num_class' ] = params .n_classes
105109
Original file line number Diff line number Diff line change @@ -30,7 +30,7 @@ def convert_xgb_predictions(y_pred, objective):
3030 if objective == 'multi:softprob' :
3131 y_pred = convert_probs_to_classes (y_pred )
3232 elif objective == 'binary:logistic' :
33- y_pred = y_pred .astype (np .int32 )
33+ y_pred = ( y_pred >= 0.5 ) .astype (np .int32 )
3434 return y_pred
3535
3636
You can’t perform that action at this time.
0 commit comments