Skip to content

Commit 30eaebd

Browse files
committed
addressing suggested changes
1 parent 4a3348d commit 30eaebd

2 files changed

Lines changed: 88 additions & 37 deletions

File tree

causalml/metrics/rate.py

Lines changed: 59 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,26 @@ def get_toc(
1717
treatment_col="w",
1818
treatment_effect_col="tau",
1919
normalize=False,
20-
random_seed=42,
2120
):
2221
"""Get the Targeting Operator Characteristic (TOC) of model estimates in population.
2322
2423
TOC(q) is the difference between the ATE among the top-q fraction of units ranked
2524
by the prioritization score and the overall ATE. A positive TOC at low q indicates
2625
the model successfully identifies units with above-average treatment benefit.
2726
27+
TOC(0) is set to 0 by convention, and TOC(1) = 0 because the subset ATE equals the
28+
overall ATE when the entire population is selected.
29+
2830
If the true treatment effect is provided (e.g. in synthetic data), it's used directly
2931
to calculate TOC. Otherwise, it's estimated as the difference between the mean outcomes
3032
of the treatment and control groups within each top-q subset.
3133
34+
Note: when using observed outcomes, if the top-q subset contains only treated or only
35+
control units, the code falls back to TOC(q) = 0 for that q (i.e., subset ATE is
36+
set to the overall ATE). This is a conservative approximation and is logged as a warning.
37+
3238
For details, see Yadlowsky et al. (2021), `Evaluating Treatment Prioritization Rules
33-
via Rank-Weighted Average Treatment Effects`.
39+
via Rank-Weighted Average Treatment Effects`. https://arxiv.org/abs/2111.07966
3440
3541
For the former, `treatment_effect_col` should be provided. For the latter, both
3642
`outcome_col` and `treatment_col` should be provided.
@@ -40,8 +46,9 @@ def get_toc(
4046
outcome_col (str, optional): the column name for the actual outcome
4147
treatment_col (str, optional): the column name for the treatment indicator (0 or 1)
4248
treatment_effect_col (str, optional): the column name for the true treatment effect
43-
normalize (bool, optional): whether to normalize the y-axis to 1 or not
44-
random_seed (int, optional): deprecated
49+
normalize (bool, optional): whether to normalize the TOC curve by its maximum
50+
absolute value. Uses max(|TOC|) as the reference to avoid division by zero
51+
at q=1 where TOC is always zero by definition.
4552
4653
Returns:
4754
(pandas.DataFrame): TOC values of model estimates in population, indexed by quantile q
@@ -80,33 +87,46 @@ def get_toc(
8087
)
8188

8289
n_total = len(df)
83-
quantiles = np.linspace(0, 1, n_total + 1)[1:]
8490

8591
toc = []
8692
for col in model_names:
8793
sorted_df = df.sort_values(col, ascending=False).reset_index(drop=True)
88-
toc_values = []
89-
90-
for q in quantiles:
91-
top_k = max(1, int(np.ceil(q * n_total)))
92-
top_subset = sorted_df.iloc[:top_k]
9394

94-
if use_oracle:
95-
subset_ate = top_subset[treatment_effect_col].mean()
96-
else:
97-
t_mask = top_subset[treatment_col] == 1
98-
c_mask = top_subset[treatment_col] == 0
99-
if t_mask.sum() == 0 or c_mask.sum() == 0:
100-
subset_ate = overall_ate
101-
else:
102-
subset_ate = (
103-
top_subset.loc[t_mask, outcome_col].mean()
104-
- top_subset.loc[c_mask, outcome_col].mean()
105-
)
106-
107-
toc_values.append(subset_ate - overall_ate)
108-
109-
toc.append(pd.Series(toc_values, index=quantiles))
95+
if use_oracle:
96+
# O(n) via cumulative sum
97+
cumsum_tau = sorted_df[treatment_effect_col].cumsum().values
98+
counts = np.arange(1, n_total + 1)
99+
subset_ates = cumsum_tau / counts
100+
else:
101+
cumsum_tr = sorted_df[treatment_col].cumsum().values
102+
cumsum_ct = np.arange(1, n_total + 1) - cumsum_tr
103+
cumsum_y_tr = (
104+
(sorted_df[outcome_col] * sorted_df[treatment_col]).cumsum().values
105+
)
106+
cumsum_y_ct = (
107+
(sorted_df[outcome_col] * (1 - sorted_df[treatment_col]))
108+
.cumsum()
109+
.values
110+
)
111+
112+
# Guard against division by zero when a band is all-treated or all-control;
113+
# fall back to overall_ate (TOC = 0) for those positions.
114+
with np.errstate(invalid="ignore", divide="ignore"):
115+
subset_ates = np.where(
116+
(cumsum_tr == 0) | (cumsum_ct == 0),
117+
overall_ate,
118+
cumsum_y_tr / cumsum_tr - cumsum_y_ct / cumsum_ct,
119+
)
120+
121+
if np.any((cumsum_tr == 0) | (cumsum_ct == 0)):
122+
logger.warning(
123+
"Some quantile bands contain only treated or only control units "
124+
"for column '%s'. TOC is set to 0 for those positions.",
125+
col,
126+
)
127+
128+
toc_values = subset_ates - overall_ate
129+
toc.append(pd.Series(toc_values, index=np.linspace(0, 1, n_total + 1)[1:]))
110130

111131
toc = pd.concat(toc, join="inner", axis=1)
112132
toc.loc[0] = np.zeros((toc.shape[1],))
@@ -115,7 +135,11 @@ def get_toc(
115135
toc.index.name = "q"
116136

117137
if normalize:
118-
toc = toc.div(np.abs(toc.iloc[-1, :]), axis=1)
138+
# Normalize by max absolute value rather than the value at q=1, which is
139+
# always zero by definition and would cause division by zero.
140+
max_abs = toc.abs().max()
141+
max_abs = max_abs.replace(0, 1) # guard for flat TOC curves
142+
toc = toc.div(max_abs, axis=1)
119143

120144
return toc
121145

@@ -127,7 +151,6 @@ def rate_score(
127151
treatment_effect_col="tau",
128152
weighting="autoc",
129153
normalize=False,
130-
random_seed=42,
131154
):
132155
"""Calculate the Rank-weighted Average Treatment Effect (RATE) score.
133156
@@ -147,6 +170,11 @@ def rate_score(
147170
above-average treatment benefit. A RATE near zero suggests little heterogeneity or
148171
a poor prioritization rule.
149172
173+
Note: the integral is approximated via a weighted mean over the discrete quantile grid
174+
using midpoint values. Weights are normalized to sum to 1 (i.e. ``weights / weights.sum()``),
175+
so the absolute scale matches the TOC values but may differ slightly from the paper's
176+
continuous integral definition. Model rankings are preserved.
177+
150178
For details, see Yadlowsky et al. (2021), `Evaluating Treatment Prioritization Rules
151179
via Rank-Weighted Average Treatment Effects`. https://arxiv.org/abs/2111.07966
152180
@@ -161,7 +189,6 @@ def rate_score(
161189
weighting (str, optional): the weighting scheme for the RATE integral.
162190
One of ``"autoc"`` (default) or ``"qini"``.
163191
normalize (bool, optional): whether to normalize the TOC curve before scoring
164-
random_seed (int, optional): deprecated
165192
166193
Returns:
167194
(pandas.Series): RATE scores of model estimates
@@ -181,7 +208,7 @@ def rate_score(
181208
normalize=normalize,
182209
)
183210

184-
quantiles = toc.index.values # shape (n,), 0 included
211+
quantiles = toc.index.values # includes 0 and 1
185212

186213
# Use midpoints to avoid division by zero for autoc at q=0
187214
q_mid = (quantiles[:-1] + quantiles[1:]) / 2
@@ -209,7 +236,6 @@ def plot_toc(
209236
treatment_col="w",
210237
treatment_effect_col="tau",
211238
normalize=False,
212-
random_seed=42,
213239
n=100,
214240
figsize=(8, 8),
215241
ax: Optional[plt.Axes] = None,
@@ -234,8 +260,8 @@ def plot_toc(
234260
outcome_col (str, optional): the column name for the actual outcome
235261
treatment_col (str, optional): the column name for the treatment indicator (0 or 1)
236262
treatment_effect_col (str, optional): the column name for the true treatment effect
237-
normalize (bool, optional): whether to normalize the y-axis to 1 or not
238-
random_seed (int, optional): deprecated
263+
normalize (bool, optional): whether to normalize the TOC curve by its maximum
264+
absolute value before plotting
239265
n (int, optional): the number of samples to be used for plotting
240266
figsize (tuple, optional): the size of the figure to plot
241267
ax (plt.Axes, optional): an existing axes object to draw on

tests/test_rate.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,16 @@
55

66
from causalml.metrics.rate import get_toc, rate_score, plot_toc
77

8+
# Prefer the project-wide seed constant; fall back to 42 only if tests.const cannot be imported
9+
try:
10+
from tests.const import RANDOM_SEED
11+
except ImportError:
12+
RANDOM_SEED = 42
13+
814

915
@pytest.fixture
1016
def synthetic_df():
11-
rng = np.random.default_rng(42)
17+
rng = np.random.default_rng(RANDOM_SEED)
1218
n = 500
1319
tau = rng.normal(1.0, 0.5, n)
1420
return pd.DataFrame(
@@ -24,7 +30,7 @@ def synthetic_df():
2430

2531
@pytest.fixture
2632
def rct_df():
27-
rng = np.random.default_rng(0)
33+
rng = np.random.default_rng(RANDOM_SEED)
2834
n = 2000
2935
tau = rng.uniform(0, 2, n)
3036
w = rng.integers(0, 2, n)
@@ -44,7 +50,7 @@ def test_get_toc_errors_on_nan():
4450
columns=["w", "y", "pred"],
4551
)
4652

47-
with pytest.raises(Exception):
53+
with pytest.raises(AssertionError):
4854
get_toc(df)
4955

5056

@@ -64,6 +70,12 @@ def test_get_toc_starts_at_zero(synthetic_df):
6470
assert np.allclose(toc.iloc[0].values, 0.0)
6571

6672

73+
def test_get_toc_ends_at_zero(synthetic_df):
74+
# TOC(1) = 0 by definition: subset ATE over full population equals overall ATE
75+
toc = get_toc(synthetic_df, treatment_effect_col="tau")
76+
assert np.allclose(toc.iloc[-1].values, 0.0, atol=1e-10)
77+
78+
6779
def test_get_toc_model_cols_present(synthetic_df):
6880
toc = get_toc(synthetic_df, treatment_effect_col="tau")
6981
assert "perfect_model" in toc.columns
@@ -75,6 +87,19 @@ def test_get_toc_perfect_model_positive_at_low_q(synthetic_df):
7587
assert toc.loc[toc.index <= 0.15, "perfect_model"].mean() > 0
7688

7789

90+
def test_get_toc_normalize_no_inf_or_nan(synthetic_df):
91+
# normalize=True previously divided by TOC(1) which is always zero —
92+
# now uses max(|TOC|) to avoid inf/NaN
93+
toc = get_toc(synthetic_df, treatment_effect_col="tau", normalize=True)
94+
assert not toc.isnull().any().any()
95+
assert np.isfinite(toc.values).all()
96+
97+
98+
def test_get_toc_normalize_max_abs_is_one(synthetic_df):
99+
toc = get_toc(synthetic_df, treatment_effect_col="tau", normalize=True)
100+
assert np.allclose(toc.abs().max(), 1.0)
101+
102+
78103
def test_get_toc_observed_outcomes(rct_df):
79104
toc = get_toc(rct_df, outcome_col="y", treatment_col="w")
80105
assert isinstance(toc, pd.DataFrame)
@@ -107,7 +132,7 @@ def test_rate_score_invalid_weighting_raises(synthetic_df):
107132

108133

109134
def test_rate_score_constant_tau_near_zero():
110-
rng = np.random.default_rng(99)
135+
rng = np.random.default_rng(RANDOM_SEED)
111136
n = 2000
112137
df = pd.DataFrame(
113138
{

0 commit comments

Comments
 (0)