Skip to content

Commit 1c555cf

Browse files
authored
Add Rank-weighted Average Treatment Effect (RATE) metric (#887)
* adding RATE metric for CATE evaluators * addressing suggested changes
1 parent c043043 commit 1c555cf

3 files changed

Lines changed: 463 additions & 0 deletions

File tree

causalml/metrics/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
get_tmleqini,
2626
) # noqa
2727
from .visualize import auuc_score, qini_score # noqa
28+
from .rate import get_toc, rate_score, plot_toc # noqa
2829
from .sensitivity import Sensitivity, SensitivityPlaceboTreatment # noqa
2930
from .sensitivity import (
3031
SensitivityRandomCause,

causalml/metrics/rate.py

Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,301 @@
1+
from typing import Optional
2+
3+
import numpy as np
4+
import pandas as pd
5+
from matplotlib import pyplot as plt
6+
import logging
7+
8+
plt.style.use("fivethirtyeight")
9+
RANDOM_COL = "Random"
10+
11+
logger = logging.getLogger("causalml")
12+
13+
14+
def get_toc(
15+
df,
16+
outcome_col="y",
17+
treatment_col="w",
18+
treatment_effect_col="tau",
19+
normalize=False,
20+
):
21+
"""Get the Targeting Operator Characteristic (TOC) of model estimates in population.
22+
23+
TOC(q) is the difference between the ATE among the top-q fraction of units ranked
24+
by the prioritization score and the overall ATE. A positive TOC at low q indicates
25+
the model successfully identifies units with above-average treatment benefit.
26+
27+
By definition, TOC(0) = 0 and TOC(1) = 0 (the subset ATE equals the overall ATE
28+
when the entire population is selected).
29+
30+
If the true treatment effect is provided (e.g. in synthetic data), it's used directly
31+
to calculate TOC. Otherwise, it's estimated as the difference between the mean outcomes
32+
of the treatment and control groups in each quantile band.
33+
34+
Note: when using observed outcomes, if a quantile band contains only treated or only
35+
control units, the code falls back to TOC(q) = 0 for that band (i.e., subset ATE is
36+
set to the overall ATE). This is a conservative approximation and is logged as a warning.
37+
38+
For details, see Yadlowsky et al. (2021), `Evaluating Treatment Prioritization Rules
39+
via Rank-Weighted Average Treatment Effects`. https://arxiv.org/abs/2111.07966
40+
41+
For the former, `treatment_effect_col` should be provided. For the latter, both
42+
`outcome_col` and `treatment_col` should be provided.
43+
44+
Args:
45+
df (pandas.DataFrame): a data frame with model estimates and actual data as columns
46+
outcome_col (str, optional): the column name for the actual outcome
47+
treatment_col (str, optional): the column name for the treatment indicator (0 or 1)
48+
treatment_effect_col (str, optional): the column name for the true treatment effect
49+
normalize (bool, optional): whether to normalize the TOC curve by its maximum
50+
absolute value. Uses max(|TOC|) as the reference to avoid division by zero
51+
at q=1 where TOC is always zero by definition.
52+
53+
Returns:
54+
(pandas.DataFrame): TOC values of model estimates in population, indexed by quantile q
55+
"""
56+
assert (
57+
(outcome_col in df.columns and df[outcome_col].notnull().all())
58+
and (treatment_col in df.columns and df[treatment_col].notnull().all())
59+
or (
60+
treatment_effect_col in df.columns
61+
and df[treatment_effect_col].notnull().all()
62+
)
63+
), "{outcome_col} and {treatment_col}, or {treatment_effect_col} should be present without null.".format(
64+
outcome_col=outcome_col,
65+
treatment_col=treatment_col,
66+
treatment_effect_col=treatment_effect_col,
67+
)
68+
69+
df = df.copy()
70+
71+
model_names = [
72+
x
73+
for x in df.columns
74+
if x not in [outcome_col, treatment_col, treatment_effect_col]
75+
]
76+
77+
use_oracle = (
78+
treatment_effect_col in df.columns and df[treatment_effect_col].notnull().all()
79+
)
80+
81+
if use_oracle:
82+
overall_ate = df[treatment_effect_col].mean()
83+
else:
84+
treated = df[treatment_col] == 1
85+
overall_ate = (
86+
df.loc[treated, outcome_col].mean() - df.loc[~treated, outcome_col].mean()
87+
)
88+
89+
n_total = len(df)
90+
91+
toc = []
92+
for col in model_names:
93+
sorted_df = df.sort_values(col, ascending=False).reset_index(drop=True)
94+
95+
if use_oracle:
96+
# O(n) via cumulative sum
97+
cumsum_tau = sorted_df[treatment_effect_col].cumsum().values
98+
counts = np.arange(1, n_total + 1)
99+
subset_ates = cumsum_tau / counts
100+
else:
101+
cumsum_tr = sorted_df[treatment_col].cumsum().values
102+
cumsum_ct = np.arange(1, n_total + 1) - cumsum_tr
103+
cumsum_y_tr = (
104+
(sorted_df[outcome_col] * sorted_df[treatment_col]).cumsum().values
105+
)
106+
cumsum_y_ct = (
107+
(sorted_df[outcome_col] * (1 - sorted_df[treatment_col]))
108+
.cumsum()
109+
.values
110+
)
111+
112+
# Guard against division by zero when a band is all-treated or all-control;
113+
# fall back to overall_ate (TOC = 0) for those positions.
114+
with np.errstate(invalid="ignore", divide="ignore"):
115+
subset_ates = np.where(
116+
(cumsum_tr == 0) | (cumsum_ct == 0),
117+
overall_ate,
118+
cumsum_y_tr / cumsum_tr - cumsum_y_ct / cumsum_ct,
119+
)
120+
121+
if np.any((cumsum_tr == 0) | (cumsum_ct == 0)):
122+
logger.warning(
123+
"Some quantile bands contain only treated or only control units "
124+
"for column '%s'. TOC is set to 0 for those positions.",
125+
col,
126+
)
127+
128+
toc_values = subset_ates - overall_ate
129+
toc.append(pd.Series(toc_values, index=np.linspace(0, 1, n_total + 1)[1:]))
130+
131+
toc = pd.concat(toc, join="inner", axis=1)
132+
toc.loc[0] = np.zeros((toc.shape[1],))
133+
toc = toc.sort_index().interpolate()
134+
toc.columns = model_names
135+
toc.index.name = "q"
136+
137+
if normalize:
138+
# Normalize by max absolute value rather than the value at q=1, which is
139+
# always zero by definition and would cause division by zero.
140+
max_abs = toc.abs().max()
141+
max_abs = max_abs.replace(0, 1) # guard for flat TOC curves
142+
toc = toc.div(max_abs, axis=1)
143+
144+
return toc
145+
146+
147+
def rate_score(
148+
df,
149+
outcome_col="y",
150+
treatment_col="w",
151+
treatment_effect_col="tau",
152+
weighting="autoc",
153+
normalize=False,
154+
):
155+
"""Calculate the Rank-weighted Average Treatment Effect (RATE) score.
156+
157+
RATE is the weighted area under the Targeting Operator Characteristic (TOC) curve:
158+
159+
RATE = integral_0^1 alpha(q) * TOC(q) dq
160+
161+
Two standard weighting schemes are supported (Yadlowsky et al., 2021):
162+
163+
- ``"autoc"``: alpha(q) = 1/q. Places more weight on the highest-priority units.
164+
Most powerful when treatment effects are concentrated in a small subgroup.
165+
166+
- ``"qini"``: alpha(q) = q. Uniform weighting across units; reduces to the Qini
167+
coefficient. More powerful when treatment effects are diffuse across the population.
168+
169+
A positive RATE indicates the prioritization rule effectively identifies units with
170+
above-average treatment benefit. A RATE near zero suggests little heterogeneity or
171+
a poor prioritization rule.
172+
173+
Note: the integral is approximated via a weighted mean over the discrete quantile grid
174+
using midpoint values. Weights are normalized to sum to 1 (i.e. ``weights / weights.sum()``),
175+
so the absolute scale matches the TOC values but may differ slightly from the paper's
176+
continuous integral definition. Model rankings are preserved.
177+
178+
For details, see Yadlowsky et al. (2021), `Evaluating Treatment Prioritization Rules
179+
via Rank-Weighted Average Treatment Effects`. https://arxiv.org/abs/2111.07966
180+
181+
For the former, `treatment_effect_col` should be provided. For the latter, both
182+
`outcome_col` and `treatment_col` should be provided.
183+
184+
Args:
185+
df (pandas.DataFrame): a data frame with model estimates and actual data as columns
186+
outcome_col (str, optional): the column name for the actual outcome
187+
treatment_col (str, optional): the column name for the treatment indicator (0 or 1)
188+
treatment_effect_col (str, optional): the column name for the true treatment effect
189+
weighting (str, optional): the weighting scheme for the RATE integral.
190+
One of ``"autoc"`` (default) or ``"qini"``.
191+
normalize (bool, optional): whether to normalize the TOC curve before scoring
192+
193+
Returns:
194+
(pandas.Series): RATE scores of model estimates
195+
"""
196+
assert weighting in (
197+
"autoc",
198+
"qini",
199+
), "{} weighting is not implemented. Select one of {}".format(
200+
weighting, ("autoc", "qini")
201+
)
202+
203+
toc = get_toc(
204+
df,
205+
outcome_col=outcome_col,
206+
treatment_col=treatment_col,
207+
treatment_effect_col=treatment_effect_col,
208+
normalize=normalize,
209+
)
210+
211+
quantiles = toc.index.values # includes 0 and 1
212+
213+
# Use midpoints to avoid division by zero for autoc at q=0
214+
q_mid = (quantiles[:-1] + quantiles[1:]) / 2
215+
toc_mid = (toc.iloc[:-1].values + toc.iloc[1:].values) / 2
216+
217+
if weighting == "autoc":
218+
weights = 1.0 / q_mid
219+
else:
220+
weights = q_mid
221+
222+
# Normalize weights so they sum to 1 over the integration domain
223+
weights = weights / weights.sum()
224+
225+
rate = pd.Series(
226+
np.average(toc_mid, axis=0, weights=weights),
227+
index=toc.columns,
228+
)
229+
rate.name = "RATE ({})".format(weighting)
230+
return rate
231+
232+
233+
def plot_toc(
234+
df,
235+
outcome_col="y",
236+
treatment_col="w",
237+
treatment_effect_col="tau",
238+
normalize=False,
239+
n=100,
240+
figsize=(8, 8),
241+
ax: Optional[plt.Axes] = None,
242+
) -> plt.Axes:
243+
"""Plot the Targeting Operator Characteristic (TOC) curve of model estimates.
244+
245+
The TOC(q) shows the excess ATE when treating only the top-q fraction of units
246+
prioritized by a model score, relative to the overall ATE. A positive and steeply
247+
decreasing curve indicates the model effectively ranks high-benefit units first.
248+
249+
If the true treatment effect is provided (e.g. in synthetic data), it's used directly.
250+
Otherwise, it's estimated from observed outcomes and treatment assignments.
251+
252+
For details, see Yadlowsky et al. (2021), `Evaluating Treatment Prioritization Rules
253+
via Rank-Weighted Average Treatment Effects`. https://arxiv.org/abs/2111.07966
254+
255+
For the former, `treatment_effect_col` should be provided. For the latter, both
256+
`outcome_col` and `treatment_col` should be provided.
257+
258+
Args:
259+
df (pandas.DataFrame): a data frame with model estimates and actual data as columns
260+
outcome_col (str, optional): the column name for the actual outcome
261+
treatment_col (str, optional): the column name for the treatment indicator (0 or 1)
262+
treatment_effect_col (str, optional): the column name for the true treatment effect
263+
normalize (bool, optional): whether to normalize the TOC curve by its maximum
264+
absolute value before plotting
265+
n (int, optional): the number of samples to be used for plotting
266+
figsize (tuple, optional): the size of the figure to plot
267+
ax (plt.Axes, optional): an existing axes object to draw on
268+
269+
Returns:
270+
(plt.Axes): the matplotlib Axes with the TOC plot
271+
"""
272+
toc = get_toc(
273+
df,
274+
outcome_col=outcome_col,
275+
treatment_col=treatment_col,
276+
treatment_effect_col=treatment_effect_col,
277+
normalize=normalize,
278+
)
279+
280+
if (n is not None) and (n < toc.shape[0]):
281+
toc = toc.iloc[np.linspace(0, len(toc) - 1, n, endpoint=True).astype(int)]
282+
283+
if ax is None:
284+
_, ax = plt.subplots(figsize=figsize)
285+
286+
ax = toc.plot(ax=ax)
287+
288+
# Random baseline (TOC = 0 everywhere)
289+
ax.plot(
290+
[toc.index[0], toc.index[-1]],
291+
[0, 0],
292+
label=RANDOM_COL,
293+
color="k",
294+
linestyle="--",
295+
)
296+
ax.legend()
297+
ax.set_xlabel("Fraction treated (q)")
298+
ax.set_ylabel("TOC(q)")
299+
ax.set_title("Targeting Operator Characteristic (TOC)")
300+
301+
return ax

0 commit comments

Comments
 (0)