Skip to content

Commit 4a3348d

Browse files
committed
adding RATE metric for CATE evaluators
1 parent 123cc1f commit 4a3348d

3 files changed

Lines changed: 412 additions & 0 deletions

File tree

causalml/metrics/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
get_tmleqini,
2626
) # noqa
2727
from .visualize import auuc_score, qini_score # noqa
28+
from .rate import get_toc, rate_score, plot_toc # noqa
2829
from .sensitivity import Sensitivity, SensitivityPlaceboTreatment # noqa
2930
from .sensitivity import (
3031
SensitivityRandomCause,

causalml/metrics/rate.py

Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
from typing import Optional
2+
3+
import numpy as np
4+
import pandas as pd
5+
from matplotlib import pyplot as plt
6+
import logging
7+
8+
plt.style.use("fivethirtyeight")
9+
RANDOM_COL = "Random"
10+
11+
logger = logging.getLogger("causalml")
12+
13+
14+
def get_toc(
15+
df,
16+
outcome_col="y",
17+
treatment_col="w",
18+
treatment_effect_col="tau",
19+
normalize=False,
20+
random_seed=42,
21+
):
22+
"""Get the Targeting Operator Characteristic (TOC) of model estimates in population.
23+
24+
TOC(q) is the difference between the ATE among the top-q fraction of units ranked
25+
by the prioritization score and the overall ATE. A positive TOC at low q indicates
26+
the model successfully identifies units with above-average treatment benefit.
27+
28+
If the true treatment effect is provided (e.g. in synthetic data), it's used directly
29+
to calculate TOC. Otherwise, it's estimated as the difference between the mean outcomes
30+
of the treatment and control groups in each quantile band.
31+
32+
For details, see Yadlowsky et al. (2021), `Evaluating Treatment Prioritization Rules
33+
via Rank-Weighted Average Treatment Effects`.
34+
35+
For the former, `treatment_effect_col` should be provided. For the latter, both
36+
`outcome_col` and `treatment_col` should be provided.
37+
38+
Args:
39+
df (pandas.DataFrame): a data frame with model estimates and actual data as columns
40+
outcome_col (str, optional): the column name for the actual outcome
41+
treatment_col (str, optional): the column name for the treatment indicator (0 or 1)
42+
treatment_effect_col (str, optional): the column name for the true treatment effect
43+
normalize (bool, optional): whether to normalize the y-axis to 1 or not
44+
random_seed (int, optional): deprecated
45+
46+
Returns:
47+
(pandas.DataFrame): TOC values of model estimates in population, indexed by quantile q
48+
"""
49+
assert (
50+
(outcome_col in df.columns and df[outcome_col].notnull().all())
51+
and (treatment_col in df.columns and df[treatment_col].notnull().all())
52+
or (
53+
treatment_effect_col in df.columns
54+
and df[treatment_effect_col].notnull().all()
55+
)
56+
), "{outcome_col} and {treatment_col}, or {treatment_effect_col} should be present without null.".format(
57+
outcome_col=outcome_col,
58+
treatment_col=treatment_col,
59+
treatment_effect_col=treatment_effect_col,
60+
)
61+
62+
df = df.copy()
63+
64+
model_names = [
65+
x
66+
for x in df.columns
67+
if x not in [outcome_col, treatment_col, treatment_effect_col]
68+
]
69+
70+
use_oracle = (
71+
treatment_effect_col in df.columns and df[treatment_effect_col].notnull().all()
72+
)
73+
74+
if use_oracle:
75+
overall_ate = df[treatment_effect_col].mean()
76+
else:
77+
treated = df[treatment_col] == 1
78+
overall_ate = (
79+
df.loc[treated, outcome_col].mean() - df.loc[~treated, outcome_col].mean()
80+
)
81+
82+
n_total = len(df)
83+
quantiles = np.linspace(0, 1, n_total + 1)[1:]
84+
85+
toc = []
86+
for col in model_names:
87+
sorted_df = df.sort_values(col, ascending=False).reset_index(drop=True)
88+
toc_values = []
89+
90+
for q in quantiles:
91+
top_k = max(1, int(np.ceil(q * n_total)))
92+
top_subset = sorted_df.iloc[:top_k]
93+
94+
if use_oracle:
95+
subset_ate = top_subset[treatment_effect_col].mean()
96+
else:
97+
t_mask = top_subset[treatment_col] == 1
98+
c_mask = top_subset[treatment_col] == 0
99+
if t_mask.sum() == 0 or c_mask.sum() == 0:
100+
subset_ate = overall_ate
101+
else:
102+
subset_ate = (
103+
top_subset.loc[t_mask, outcome_col].mean()
104+
- top_subset.loc[c_mask, outcome_col].mean()
105+
)
106+
107+
toc_values.append(subset_ate - overall_ate)
108+
109+
toc.append(pd.Series(toc_values, index=quantiles))
110+
111+
toc = pd.concat(toc, join="inner", axis=1)
112+
toc.loc[0] = np.zeros((toc.shape[1],))
113+
toc = toc.sort_index().interpolate()
114+
toc.columns = model_names
115+
toc.index.name = "q"
116+
117+
if normalize:
118+
toc = toc.div(np.abs(toc.iloc[-1, :]), axis=1)
119+
120+
return toc
121+
122+
123+
def rate_score(
124+
df,
125+
outcome_col="y",
126+
treatment_col="w",
127+
treatment_effect_col="tau",
128+
weighting="autoc",
129+
normalize=False,
130+
random_seed=42,
131+
):
132+
"""Calculate the Rank-weighted Average Treatment Effect (RATE) score.
133+
134+
RATE is the weighted area under the Targeting Operator Characteristic (TOC) curve:
135+
136+
RATE = integral_0^1 alpha(q) * TOC(q) dq
137+
138+
Two standard weighting schemes are supported (Yadlowsky et al., 2021):
139+
140+
- ``"autoc"``: alpha(q) = 1/q. Places more weight on the highest-priority units.
141+
Most powerful when treatment effects are concentrated in a small subgroup.
142+
143+
- ``"qini"``: alpha(q) = q. Uniform weighting across units; reduces to the Qini
144+
coefficient. More powerful when treatment effects are diffuse across the population.
145+
146+
A positive RATE indicates the prioritization rule effectively identifies units with
147+
above-average treatment benefit. A RATE near zero suggests little heterogeneity or
148+
a poor prioritization rule.
149+
150+
For details, see Yadlowsky et al. (2021), `Evaluating Treatment Prioritization Rules
151+
via Rank-Weighted Average Treatment Effects`. https://arxiv.org/abs/2111.07966
152+
153+
For the former, `treatment_effect_col` should be provided. For the latter, both
154+
`outcome_col` and `treatment_col` should be provided.
155+
156+
Args:
157+
df (pandas.DataFrame): a data frame with model estimates and actual data as columns
158+
outcome_col (str, optional): the column name for the actual outcome
159+
treatment_col (str, optional): the column name for the treatment indicator (0 or 1)
160+
treatment_effect_col (str, optional): the column name for the true treatment effect
161+
weighting (str, optional): the weighting scheme for the RATE integral.
162+
One of ``"autoc"`` (default) or ``"qini"``.
163+
normalize (bool, optional): whether to normalize the TOC curve before scoring
164+
random_seed (int, optional): deprecated
165+
166+
Returns:
167+
(pandas.Series): RATE scores of model estimates
168+
"""
169+
assert weighting in (
170+
"autoc",
171+
"qini",
172+
), "{} weighting is not implemented. Select one of {}".format(
173+
weighting, ("autoc", "qini")
174+
)
175+
176+
toc = get_toc(
177+
df,
178+
outcome_col=outcome_col,
179+
treatment_col=treatment_col,
180+
treatment_effect_col=treatment_effect_col,
181+
normalize=normalize,
182+
)
183+
184+
quantiles = toc.index.values # shape (n,), 0 included
185+
186+
# Use midpoints to avoid division by zero for autoc at q=0
187+
q_mid = (quantiles[:-1] + quantiles[1:]) / 2
188+
toc_mid = (toc.iloc[:-1].values + toc.iloc[1:].values) / 2
189+
190+
if weighting == "autoc":
191+
weights = 1.0 / q_mid
192+
else:
193+
weights = q_mid
194+
195+
# Normalize weights so they sum to 1 over the integration domain
196+
weights = weights / weights.sum()
197+
198+
rate = pd.Series(
199+
np.average(toc_mid, axis=0, weights=weights),
200+
index=toc.columns,
201+
)
202+
rate.name = "RATE ({})".format(weighting)
203+
return rate
204+
205+
206+
def plot_toc(
207+
df,
208+
outcome_col="y",
209+
treatment_col="w",
210+
treatment_effect_col="tau",
211+
normalize=False,
212+
random_seed=42,
213+
n=100,
214+
figsize=(8, 8),
215+
ax: Optional[plt.Axes] = None,
216+
) -> plt.Axes:
217+
"""Plot the Targeting Operator Characteristic (TOC) curve of model estimates.
218+
219+
The TOC(q) shows the excess ATE when treating only the top-q fraction of units
220+
prioritized by a model score, relative to the overall ATE. A positive and steeply
221+
decreasing curve indicates the model effectively ranks high-benefit units first.
222+
223+
If the true treatment effect is provided (e.g. in synthetic data), it's used directly.
224+
Otherwise, it's estimated from observed outcomes and treatment assignments.
225+
226+
For details, see Yadlowsky et al. (2021), `Evaluating Treatment Prioritization Rules
227+
via Rank-Weighted Average Treatment Effects`. https://arxiv.org/abs/2111.07966
228+
229+
For the former, `treatment_effect_col` should be provided. For the latter, both
230+
`outcome_col` and `treatment_col` should be provided.
231+
232+
Args:
233+
df (pandas.DataFrame): a data frame with model estimates and actual data as columns
234+
outcome_col (str, optional): the column name for the actual outcome
235+
treatment_col (str, optional): the column name for the treatment indicator (0 or 1)
236+
treatment_effect_col (str, optional): the column name for the true treatment effect
237+
normalize (bool, optional): whether to normalize the y-axis to 1 or not
238+
random_seed (int, optional): deprecated
239+
n (int, optional): the number of samples to be used for plotting
240+
figsize (tuple, optional): the size of the figure to plot
241+
ax (plt.Axes, optional): an existing axes object to draw on
242+
243+
Returns:
244+
(plt.Axes): the matplotlib Axes with the TOC plot
245+
"""
246+
toc = get_toc(
247+
df,
248+
outcome_col=outcome_col,
249+
treatment_col=treatment_col,
250+
treatment_effect_col=treatment_effect_col,
251+
normalize=normalize,
252+
)
253+
254+
if (n is not None) and (n < toc.shape[0]):
255+
toc = toc.iloc[np.linspace(0, len(toc) - 1, n, endpoint=True).astype(int)]
256+
257+
if ax is None:
258+
_, ax = plt.subplots(figsize=figsize)
259+
260+
ax = toc.plot(ax=ax)
261+
262+
# Random baseline (TOC = 0 everywhere)
263+
ax.plot(
264+
[toc.index[0], toc.index[-1]],
265+
[0, 0],
266+
label=RANDOM_COL,
267+
color="k",
268+
linestyle="--",
269+
)
270+
ax.legend()
271+
ax.set_xlabel("Fraction treated (q)")
272+
ax.set_ylabel("TOC(q)")
273+
ax.set_title("Targeting Operator Characteristic (TOC)")
274+
275+
return ax

0 commit comments

Comments
 (0)