Skip to content

Commit 16fb8cb

Browse files
jeongyoonleeclaude
andcommitted
Add input validation to auuc_score for missing model columns (#858)
auuc_score (and the underlying get_cumlift) treats every DataFrame column not in {outcome_col, treatment_col, treatment_effect_col} as a model prediction column. When users pass a DataFrame without any model columns, the function fails with a confusing error. Now raises a clear ValueError explaining the expected DataFrame format. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent d688549 commit 16fb8cb

1 file changed

Lines changed: 12 additions & 1 deletion

File tree

causalml/metrics/visualize.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -788,7 +788,9 @@ def auuc_score(
788788
"""Calculate the AUUC (Area Under the Uplift Curve) score.
789789
790790
Args:
791-
df (pandas.DataFrame): a data frame with model estimates and actual data as columns
791+
df (pandas.DataFrame): a data frame with model estimates and actual data as columns.
792+
Columns not matching outcome_col, treatment_col, or treatment_effect_col are
793+
treated as model prediction columns whose AUUC will be computed.
792794
outcome_col (str, optional): the column name for the actual outcome
793795
treatment_col (str, optional): the column name for the treatment indicator (0 or 1)
794796
treatment_effect_col (str, optional): the column name for the true treatment effect
@@ -797,6 +799,15 @@ def auuc_score(
797799
Returns:
798800
(float): the AUUC score
799801
"""
802+
required_cols = {outcome_col, treatment_col, treatment_effect_col}
803+
model_names = [x for x in df.columns if x not in required_cols]
804+
if len(model_names) == 0:
805+
raise ValueError(
806+
"No model prediction columns found in the DataFrame. "
807+
"The DataFrame should contain at least one column besides "
808+
f"'{outcome_col}', '{treatment_col}', and '{treatment_effect_col}' "
809+
"representing model uplift predictions to score."
810+
)
800811

801812
if not tmle:
802813
cumgain = get_cumgain(

0 commit comments

Comments
 (0)