@@ -17,20 +17,26 @@ def get_toc(
1717 treatment_col = "w" ,
1818 treatment_effect_col = "tau" ,
1919 normalize = False ,
20- random_seed = 42 ,
2120):
2221 """Get the Targeting Operator Characteristic (TOC) of model estimates in population.
2322
2423 TOC(q) is the difference between the ATE among the top-q fraction of units ranked
2524 by the prioritization score and the overall ATE. A positive TOC at low q indicates
2625 the model successfully identifies units with above-average treatment benefit.
2726
27+ TOC(1) = 0 by definition (the subset ATE equals the overall ATE when the entire
28+ population is selected), and TOC(0) is set to 0 by convention.
29+
2830 If the true treatment effect is provided (e.g. in synthetic data), it's used directly
2931 to calculate TOC. Otherwise, it's estimated as the difference between the mean outcomes
3032 of the treatment and control groups within the top-q fraction of units.
3133
34+ Note: when using observed outcomes, if the top-q subset contains only treated or only
35+ control units, the code falls back to TOC(q) = 0 for that q (i.e., subset ATE is
36+ set to the overall ATE). This is a conservative approximation and is logged as a warning.
37+
3238 For details, see Yadlowsky et al. (2021), `Evaluating Treatment Prioritization Rules
33- via Rank-Weighted Average Treatment Effects`.
39+ via Rank-Weighted Average Treatment Effects`. https://arxiv.org/abs/2111.07966
3440
3541 For the former, `treatment_effect_col` should be provided. For the latter, both
3642 `outcome_col` and `treatment_col` should be provided.
@@ -40,8 +46,9 @@ def get_toc(
4046 outcome_col (str, optional): the column name for the actual outcome
4147 treatment_col (str, optional): the column name for the treatment indicator (0 or 1)
4248 treatment_effect_col (str, optional): the column name for the true treatment effect
43- normalize (bool, optional): whether to normalize the y-axis to 1 or not
44- random_seed (int, optional): deprecated
49+ normalize (bool, optional): whether to normalize the TOC curve by its maximum
50+ absolute value. Uses max(|TOC|) as the reference to avoid division by zero
51+ at q=1 where TOC is always zero by definition.
4552
4653 Returns:
4754 (pandas.DataFrame): TOC values of model estimates in population, indexed by quantile q
@@ -80,33 +87,46 @@ def get_toc(
8087 )
8188
8289 n_total = len (df )
83- quantiles = np .linspace (0 , 1 , n_total + 1 )[1 :]
8490
8591 toc = []
8692 for col in model_names :
8793 sorted_df = df .sort_values (col , ascending = False ).reset_index (drop = True )
88- toc_values = []
89-
90- for q in quantiles :
91- top_k = max (1 , int (np .ceil (q * n_total )))
92- top_subset = sorted_df .iloc [:top_k ]
9394
94- if use_oracle :
95- subset_ate = top_subset [treatment_effect_col ].mean ()
96- else :
97- t_mask = top_subset [treatment_col ] == 1
98- c_mask = top_subset [treatment_col ] == 0
99- if t_mask .sum () == 0 or c_mask .sum () == 0 :
100- subset_ate = overall_ate
101- else :
102- subset_ate = (
103- top_subset .loc [t_mask , outcome_col ].mean ()
104- - top_subset .loc [c_mask , outcome_col ].mean ()
105- )
106-
107- toc_values .append (subset_ate - overall_ate )
108-
109- toc .append (pd .Series (toc_values , index = quantiles ))
95+ if use_oracle :
96+ # O(n) via cumulative sum
97+ cumsum_tau = sorted_df [treatment_effect_col ].cumsum ().values
98+ counts = np .arange (1 , n_total + 1 )
99+ subset_ates = cumsum_tau / counts
100+ else :
101+ cumsum_tr = sorted_df [treatment_col ].cumsum ().values
102+ cumsum_ct = np .arange (1 , n_total + 1 ) - cumsum_tr
103+ cumsum_y_tr = (
104+ (sorted_df [outcome_col ] * sorted_df [treatment_col ]).cumsum ().values
105+ )
106+ cumsum_y_ct = (
107+ (sorted_df [outcome_col ] * (1 - sorted_df [treatment_col ]))
108+ .cumsum ()
109+ .values
110+ )
111+
112+ # Guard against division by zero when a top-k prefix is all-treated or all-control;
113+ # fall back to overall_ate (TOC = 0) for those positions.
114+ with np .errstate (invalid = "ignore" , divide = "ignore" ):
115+ subset_ates = np .where (
116+ (cumsum_tr == 0 ) | (cumsum_ct == 0 ),
117+ overall_ate ,
118+ cumsum_y_tr / cumsum_tr - cumsum_y_ct / cumsum_ct ,
119+ )
120+
121+ if np .any ((cumsum_tr == 0 ) | (cumsum_ct == 0 )):
122+ logger .warning (
123+ "Some quantile bands contain only treated or only control units "
124+ "for column '%s'. TOC is set to 0 for those positions." ,
125+ col ,
126+ )
127+
128+ toc_values = subset_ates - overall_ate
129+ toc .append (pd .Series (toc_values , index = np .linspace (0 , 1 , n_total + 1 )[1 :]))
110130
111131 toc = pd .concat (toc , join = "inner" , axis = 1 )
112132 toc .loc [0 ] = np .zeros ((toc .shape [1 ],))
@@ -115,7 +135,11 @@ def get_toc(
115135 toc .index .name = "q"
116136
117137 if normalize :
118- toc = toc .div (np .abs (toc .iloc [- 1 , :]), axis = 1 )
138+ # Normalize by max absolute value rather than the value at q=1, which is
139+ # always zero by definition and would cause division by zero.
140+ max_abs = toc .abs ().max ()
141+ max_abs = max_abs .replace (0 , 1 ) # guard for flat TOC curves
142+ toc = toc .div (max_abs , axis = 1 )
119143
120144 return toc
121145
@@ -127,7 +151,6 @@ def rate_score(
127151 treatment_effect_col = "tau" ,
128152 weighting = "autoc" ,
129153 normalize = False ,
130- random_seed = 42 ,
131154):
132155 """Calculate the Rank-weighted Average Treatment Effect (RATE) score.
133156
@@ -147,6 +170,11 @@ def rate_score(
147170 above-average treatment benefit. A RATE near zero suggests little heterogeneity or
148171 a poor prioritization rule.
149172
173+ Note: the integral is approximated via a weighted mean over the discrete quantile grid
174+ using midpoint values. Weights are normalized to sum to 1 (i.e. ``weights / weights.sum()``),
175+ so the absolute scale matches the TOC values but may differ slightly from the paper's
176+ continuous integral definition. Model rankings are preserved.
177+
150178 For details, see Yadlowsky et al. (2021), `Evaluating Treatment Prioritization Rules
151179 via Rank-Weighted Average Treatment Effects`. https://arxiv.org/abs/2111.07966
152180
@@ -161,7 +189,6 @@ def rate_score(
161189 weighting (str, optional): the weighting scheme for the RATE integral.
162190 One of ``"autoc"`` (default) or ``"qini"``.
163191 normalize (bool, optional): whether to normalize the TOC curve before scoring
164- random_seed (int, optional): deprecated
165192
166193 Returns:
167194 (pandas.Series): RATE scores of model estimates
@@ -181,7 +208,7 @@ def rate_score(
181208 normalize = normalize ,
182209 )
183210
184- quantiles = toc .index .values # shape (n,), 0 included
211+ quantiles = toc .index .values # includes 0 and 1
185212
186213 # Use midpoints to avoid division by zero for autoc at q=0
187214 q_mid = (quantiles [:- 1 ] + quantiles [1 :]) / 2
@@ -209,7 +236,6 @@ def plot_toc(
209236 treatment_col = "w" ,
210237 treatment_effect_col = "tau" ,
211238 normalize = False ,
212- random_seed = 42 ,
213239 n = 100 ,
214240 figsize = (8 , 8 ),
215241 ax : Optional [plt .Axes ] = None ,
@@ -234,8 +260,8 @@ def plot_toc(
234260 outcome_col (str, optional): the column name for the actual outcome
235261 treatment_col (str, optional): the column name for the treatment indicator (0 or 1)
236262 treatment_effect_col (str, optional): the column name for the true treatment effect
237- normalize (bool, optional): whether to normalize the y-axis to 1 or not
238- random_seed (int, optional): deprecated
263+ normalize (bool, optional): whether to normalize the TOC curve by its maximum
264+ absolute value before plotting
239265 n (int, optional): the number of samples to be used for plotting
240266 figsize (tuple, optional): the size of the figure to plot
241267 ax (plt.Axes, optional): an existing axes object to draw on
0 commit comments