|
| 1 | +from typing import Optional |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +import pandas as pd |
| 5 | +from matplotlib import pyplot as plt |
| 6 | +import logging |
| 7 | + |
# Switch matplotlib to the "fivethirtyeight" style sheet. This runs at import
# time, so it affects every figure drawn after this module is imported.
plt.style.use("fivethirtyeight")
# Legend label for the zero (random targeting) baseline drawn by plot_toc().
RANDOM_COL = "Random"

# Logger registered under the "causalml" name.
logger = logging.getLogger("causalml")
| 12 | + |
| 13 | + |
def get_toc(
    df,
    outcome_col="y",
    treatment_col="w",
    treatment_effect_col="tau",
    normalize=False,
    random_seed=42,
):
    """Get the Targeting Operator Characteristic (TOC) of model estimates in population.

    TOC(q) is the difference between the ATE among the top-q fraction of units ranked
    by the prioritization score and the overall ATE. A positive TOC at low q indicates
    the model successfully identifies units with above-average treatment benefit.

    Two ways of measuring the ATE are supported:

    - If `treatment_effect_col` is present without nulls (e.g. in synthetic data), the
      true treatment effect is averaged directly.
    - Otherwise `outcome_col` and `treatment_col` must both be present without nulls,
      and the ATE is estimated as the difference between the mean outcomes of the
      treatment (== 1) and control (== 0) units among the top-q fraction. When that
      subset contains no treated or no control unit, the overall ATE is used instead,
      i.e. TOC is 0 at that point.

    Every column of `df` other than the three columns above is treated as a model
    score, and one TOC curve is computed per such column.

    For details, see Yadlowsky et al. (2021), `Evaluating Treatment Prioritization Rules
    via Rank-Weighted Average Treatment Effects`. https://arxiv.org/abs/2111.07966

    Args:
        df (pandas.DataFrame): a data frame with model estimates and actual data as columns
        outcome_col (str, optional): the column name for the actual outcome
        treatment_col (str, optional): the column name for the treatment indicator (0 or 1)
        treatment_effect_col (str, optional): the column name for the true treatment effect
        normalize (bool, optional): whether to rescale each curve by its largest absolute
            value so that it peaks at +/-1. TOC(1) is 0 by definition, so the end point
            of the curve cannot be used as the scale. All-zero curves are left unchanged.
        random_seed (int, optional): deprecated, unused

    Returns:
        (pandas.DataFrame): TOC values of model estimates in population, indexed by
            quantile q = 0, 1/n, ..., 1 (TOC is set to 0 at q = 0), one column per model

    Raises:
        AssertionError: if neither the true treatment effect column nor the
            outcome/treatment column pair is present without nulls
        ValueError: if `df` has no model score column
    """

    def _complete(col):
        # True when the column exists and has no missing values.
        return col in df.columns and df[col].notnull().all()

    use_oracle = _complete(treatment_effect_col)

    assert (
        _complete(outcome_col) and _complete(treatment_col)
    ) or use_oracle, "{outcome_col} and {treatment_col}, or {treatment_effect_col} should be present without null.".format(
        outcome_col=outcome_col,
        treatment_col=treatment_col,
        treatment_effect_col=treatment_effect_col,
    )

    model_names = [
        x
        for x in df.columns
        if x not in [outcome_col, treatment_col, treatment_effect_col]
    ]
    if not model_names:
        raise ValueError(
            "df must contain at least one model score column in addition to "
            "the outcome, treatment and treatment effect columns."
        )

    if use_oracle:
        overall_ate = df[treatment_effect_col].mean()
    else:
        treated = df[treatment_col] == 1
        overall_ate = (
            df.loc[treated, outcome_col].mean() - df.loc[~treated, outcome_col].mean()
        )

    n_total = len(df)
    # Prefix sizes are kept as exact integers. Deriving them from ceil(q * n)
    # overshoots by one whenever q * n rounds just above an integer
    # (e.g. 3 * 0.1 * 10 == 3.0000000000000004).
    counts = np.arange(1, n_total + 1)
    quantiles = np.linspace(0, 1, n_total + 1)

    curves = {}
    for col in model_names:
        # Stable sort so that ties in the score are broken reproducibly.
        ranked = df.sort_values(col, ascending=False, kind="mergesort")

        # Cumulative sums give the ATE of every top-k prefix in a single pass.
        if use_oracle:
            tau = ranked[treatment_effect_col].to_numpy(dtype=float)
            top_ate = np.cumsum(tau) / counts
        else:
            y = ranked[outcome_col].to_numpy(dtype=float)
            w = ranked[treatment_col].to_numpy()
            is_treated = w == 1
            is_control = w == 0
            n_treated = np.cumsum(is_treated)
            n_control = np.cumsum(is_control)
            sum_treated = np.cumsum(np.where(is_treated, y, 0.0))
            sum_control = np.cumsum(np.where(is_control, y, 0.0))

            # Prefixes missing one of the two groups fall back to the overall ATE.
            top_ate = np.full(n_total, overall_ate, dtype=float)
            both = (n_treated > 0) & (n_control > 0)
            top_ate[both] = (
                sum_treated[both] / n_treated[both]
                - sum_control[both] / n_control[both]
            )

        # Anchor the curve at TOC(0) = 0.
        curves[col] = np.concatenate(([0.0], top_ate - overall_ate))

    toc = pd.DataFrame(curves, index=quantiles, columns=model_names)
    # If the overall ATE is undefined (e.g. no control units at all) the curve is
    # NaN everywhere except the q = 0 anchor; interpolate() carries that 0 forward.
    toc = toc.interpolate()
    toc.index.name = "q"

    if normalize:
        peak = toc.abs().max(axis=0)
        # Guard flat (all-zero) curves against division by zero.
        toc = toc.div(peak.where(peak > 0, 1.0), axis=1)

    return toc
| 121 | + |
| 122 | + |
def rate_score(
    df,
    outcome_col="y",
    treatment_col="w",
    treatment_effect_col="tau",
    weighting="autoc",
    normalize=False,
    random_seed=42,
):
    """Calculate the Rank-weighted Average Treatment Effect (RATE) score.

    RATE is the weighted area under the Targeting Operator Characteristic (TOC) curve:

        RATE = integral_0^1 alpha(q) * TOC(q) dq

    Two standard weighting schemes are supported (Yadlowsky et al., 2021):

    - ``"autoc"``: alpha(q) = 1, the plain area under the TOC curve (AUTOC). TOC(q) at
      small q averages over few units, so the highest-priority units influence the
      score the most.

    - ``"qini"``: alpha(q) = q, the Qini coefficient. The small-q part of the TOC curve
      is down-weighted relative to AUTOC.

    The integral is approximated on the quantile grid returned by `get_toc`, using the
    interval midpoints for both alpha and TOC.

    A positive RATE indicates the prioritization rule effectively identifies units with
    above-average treatment benefit. A RATE near zero suggests little heterogeneity or
    a poor prioritization rule.

    For details, see Yadlowsky et al. (2021), `Evaluating Treatment Prioritization Rules
    via Rank-Weighted Average Treatment Effects`. https://arxiv.org/abs/2111.07966

    If the true treatment effect is available, `treatment_effect_col` should be
    provided. Otherwise, both `outcome_col` and `treatment_col` should be provided.

    Args:
        df (pandas.DataFrame): a data frame with model estimates and actual data as columns
        outcome_col (str, optional): the column name for the actual outcome
        treatment_col (str, optional): the column name for the treatment indicator (0 or 1)
        treatment_effect_col (str, optional): the column name for the true treatment effect
        weighting (str, optional): the weighting scheme for the RATE integral.
            One of ``"autoc"`` (default) or ``"qini"``.
        normalize (bool, optional): whether to normalize the TOC curve before scoring
        random_seed (int, optional): deprecated, unused

    Returns:
        (pandas.Series): RATE scores of model estimates, indexed by model column name

    Raises:
        AssertionError: if `weighting` is not one of the supported schemes
    """
    assert weighting in (
        "autoc",
        "qini",
    ), "{} weighting is not implemented. Select one of {}".format(
        weighting, ("autoc", "qini")
    )

    toc = get_toc(
        df,
        outcome_col=outcome_col,
        treatment_col=treatment_col,
        treatment_effect_col=treatment_effect_col,
        normalize=normalize,
    )

    q = toc.index.to_numpy(dtype=float)  # grid over [0, 1], q = 0 included
    dq = np.diff(q)  # interval widths
    q_mid = (q[:-1] + q[1:]) / 2
    toc_mid = (toc.iloc[:-1].to_numpy() + toc.iloc[1:].to_numpy()) / 2

    if weighting == "autoc":
        alpha = np.ones_like(q_mid)
    else:
        alpha = q_mid

    # Midpoint rule: sum_i alpha(q_i) * TOC(q_i) * dq_i. The weights are deliberately
    # not renormalised, so the result is the integral itself rather than a weighted
    # average of the TOC values.
    rate = pd.Series((alpha * dq) @ toc_mid, index=toc.columns)
    rate.name = "RATE ({})".format(weighting)
    return rate
| 204 | + |
| 205 | + |
def plot_toc(
    df,
    outcome_col="y",
    treatment_col="w",
    treatment_effect_col="tau",
    normalize=False,
    random_seed=42,
    n=100,
    figsize=(8, 8),
    ax: Optional[plt.Axes] = None,
) -> plt.Axes:
    """Plot the Targeting Operator Characteristic (TOC) curve of model estimates.

    TOC(q) is the excess ATE obtained when treating only the top-q fraction of units
    prioritized by a model score, relative to the overall ATE. One curve is drawn per
    model column, together with a dashed horizontal line at zero representing random
    targeting.

    The curves are computed by `get_toc`: if the true treatment effect is available,
    `treatment_effect_col` should be provided; otherwise both `outcome_col` and
    `treatment_col` should be provided.

    For details, see Yadlowsky et al. (2021), `Evaluating Treatment Prioritization Rules
    via Rank-Weighted Average Treatment Effects`. https://arxiv.org/abs/2111.07966

    Args:
        df (pandas.DataFrame): a data frame with model estimates and actual data as columns
        outcome_col (str, optional): the column name for the actual outcome
        treatment_col (str, optional): the column name for the treatment indicator (0 or 1)
        treatment_effect_col (str, optional): the column name for the true treatment effect
        normalize (bool, optional): whether to normalize the y-axis to 1 or not
        random_seed (int, optional): deprecated
        n (int, optional): the number of evenly spaced points of the curve to plot;
            ``None`` plots every point
        figsize (tuple, optional): the size of the figure to plot (used only when `ax`
            is not given)
        ax (plt.Axes, optional): an existing axes object to draw on

    Returns:
        (plt.Axes): the matplotlib Axes with the TOC plot
    """
    curves = get_toc(
        df,
        outcome_col=outcome_col,
        treatment_col=treatment_col,
        treatment_effect_col=treatment_effect_col,
        normalize=normalize,
    )

    # Thin the curve down to n evenly spaced rows when it has more than that.
    n_rows = curves.shape[0]
    if n is not None and n < n_rows:
        keep = np.linspace(0, n_rows - 1, n, endpoint=True).astype(int)
        curves = curves.iloc[keep]

    if ax is None:
        ax = plt.subplots(figsize=figsize)[1]

    ax = curves.plot(ax=ax)

    # Random baseline (TOC = 0 everywhere)
    q_first, q_last = curves.index[0], curves.index[-1]
    ax.plot(
        [q_first, q_last],
        [0, 0],
        label=RANDOM_COL,
        color="k",
        linestyle="--",
    )
    ax.legend()
    ax.set_xlabel("Fraction treated (q)")
    ax.set_ylabel("TOC(q)")
    ax.set_title("Targeting Operator Characteristic (TOC)")

    return ax
0 commit comments