Skip to content

Commit 7d8fcbb

Browse files
authored
Merge pull request #1320 from stan-dev/stansummary-ess-per-second
Add ESS/s back to stansummary
2 parents b95fc4d + 30b65d0 commit 7d8fcbb

2 files changed

Lines changed: 56 additions & 43 deletions

File tree

src/cmdstan/stansummary_helper.hpp

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -366,28 +366,29 @@ stan::mcmc::chainset parse_csv_files(const std::vector<std::string> &filenames,
366366
*/
367367
std::vector<std::string> get_header(
368368
const std::vector<std::string> &percentiles) {
369-
// Mean, MCSE, StdDev, MAD, ... percentiles ..., ESS_bulk, ESS_tail, R_hat
370-
std::vector<std::string> header(percentiles.size() + 7);
371-
header.at(0) = "Mean";
372-
header.at(1) = "MCSE";
373-
header.at(2) = "StdDev";
374-
header.at(3) = "MAD";
375-
for (size_t i = 0; i < percentiles.size(); ++i) {
376-
header[i + 4] = percentiles[i] + '%';
369+
// Mean, MCSE, StdDev, MAD, ... percentiles ...,
370+
// ESS_bulk, ESS_tail, ESS_bulk/s, R_hat
371+
std::vector<std::string> header;
372+
header.emplace_back("Mean");
373+
header.emplace_back("MCSE");
374+
header.emplace_back("StdDev");
375+
header.emplace_back("MAD");
376+
for (auto per : percentiles) {
377+
header.push_back(per + '%');
377378
}
378-
size_t offset = 4 + percentiles.size();
379-
header.at(offset) = "ESS_bulk";
380-
header.at(offset + 1) = "ESS_tail";
381-
header.at(offset + 2) = "R_hat";
379+
header.emplace_back("ESS_bulk");
380+
header.emplace_back("ESS_tail");
381+
header.emplace_back("ESS_bulk/s");
382+
header.emplace_back("R_hat");
382383
return header;
383384
}
384385

385386
/**
386387
* Compute statistics for span of output columns
387-
* Mean, MCSE, StdDev, MAD, ... percentiles ..., ESS_bulk, ESS_tail, R_hat
388+
* Mean, MCSE, StdDev, MAD, ... percentiles ...,
389+
* ESS_bulk, ESS_tail, ESS_bulk/s, R_hat
388390
*
389391
* @param in set of samples from one or more chains
390-
* @param in vector of warmup times (required for N_eff/S)
391392
* @param in vector of sampling times (required for ESS_bulk/s)
392393
* @param in vector of probabilities
393394
* @param in vector of model param column indices in chains object
@@ -404,6 +405,8 @@ void get_stats(const stan::mcmc::chainset &chains,
404405
throw std::domain_error("get_stats: size mismatch");
405406
}
406407

408+
double total_sampling_time = sampling_times.sum();
409+
407410
// Model parameters
408411
int i = 0;
409412
for (int i_chains : cols) {
@@ -415,14 +418,15 @@ void get_stats(const stan::mcmc::chainset &chains,
415418
for (int j = 0; j < quantiles.size(); j++)
416419
params(i, 4 + j) = quantiles(j);
417420

421+
auto offset = 4 + quantiles.size();
418422
auto [ess_bulk, ess_tail] = chains.split_rank_normalized_ess(i_chains);
419423

420-
params(i, quantiles.size() + 4) = ess_bulk;
421-
params(i, quantiles.size() + 5) = ess_tail;
424+
params(i, offset++) = ess_bulk;
425+
params(i, offset++) = ess_tail;
426+
params(i, offset++) = ess_bulk / total_sampling_time;
422427

423428
auto [rhat_bulk, rhat_tail] = chains.split_rank_normalized_rhat(i_chains);
424-
params(i, quantiles.size() + 6)
425-
= rhat_bulk > rhat_tail ? rhat_bulk : rhat_tail;
429+
params(i, offset++) = rhat_bulk > rhat_tail ? rhat_bulk : rhat_tail;
426430
i++;
427431
}
428432
}

src/test/interface/stansummary_test.cpp

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,10 @@ TEST(CommandStansummary, matrix_index_2d) {
137137
TEST(CommandStansummary, header_tests) {
138138
std::string expect
139139
= " Mean MCSE StdDev MAD 10% 50% 90%"
140-
" ESS_bulk ESS_tail R_hat\n";
140+
" ESS_bulk ESS_tail ESS_bulk/s R_hat\n";
141141
std::string expect_csv
142-
= "name,Mean,MCSE,StdDev,MAD,10%,50%,90%,ESS_bulk,ESS_tail,R_hat\n";
142+
= "name,Mean,MCSE,StdDev,MAD,10%,50%,90%,ESS_bulk,ESS_tail,ESS_bulk/"
143+
"s,R_hat\n";
143144
std::vector<std::string> pcts;
144145
pcts.push_back("10");
145146
pcts.push_back("50");
@@ -150,14 +151,15 @@ TEST(CommandStansummary, header_tests) {
150151
EXPECT_FLOAT_EQ(probs[2], 0.9);
151152

152153
std::vector<std::string> header = get_header(pcts);
153-
EXPECT_EQ(header.size(), pcts.size() + 7);
154+
EXPECT_EQ(header.size(), pcts.size() + 8);
154155
EXPECT_EQ(header[0], "Mean");
155156
EXPECT_EQ(header[2], "StdDev");
156157
EXPECT_EQ(header[3], "MAD");
157158
EXPECT_EQ(header[5], "50%");
158159
EXPECT_EQ(header[7], "ESS_bulk");
159160
EXPECT_EQ(header[8], "ESS_tail");
160-
EXPECT_EQ(header[9], "R_hat");
161+
EXPECT_EQ(header[9], "ESS_bulk/s");
162+
EXPECT_EQ(header[10], "R_hat");
161163

162164
Eigen::VectorXi column_widths(header.size());
163165
for (size_t i = 0, w = 5; i < header.size(); ++i, ++w) {
@@ -239,7 +241,7 @@ TEST(CommandStansummary, param_tests) {
239241
size_t num_model_params = chains.num_params() - model_params_offset;
240242
EXPECT_EQ(num_model_params, 1);
241243

242-
Eigen::MatrixXd model_params(num_model_params, 10);
244+
Eigen::MatrixXd model_params(num_model_params, 11);
243245
std::vector<int> model_param_idxes(num_model_params);
244246
std::iota(model_param_idxes.begin(), model_param_idxes.end(),
245247
model_params_offset);
@@ -248,7 +250,7 @@ TEST(CommandStansummary, param_tests) {
248250
double mean_theta = model_params(0, 0);
249251
EXPECT_TRUE(mean_theta > 0.25);
250252
EXPECT_TRUE(mean_theta < 0.27);
251-
double rhat_theta = model_params(0, 9);
253+
double rhat_theta = model_params(0, 10);
252254
EXPECT_TRUE(rhat_theta > 0.999);
253255
EXPECT_TRUE(rhat_theta < 1.01);
254256
}
@@ -426,16 +428,16 @@ TEST(CommandStansummary, bad_include_param_args) {
426428
TEST(CommandStansummary, check_console_output) {
427429
std::string lp
428430
= "lp__ -7.3 3.7e-02 0.77 0.30 -9.0 -7.0 -6.8 "
429-
" 519 503 1.0";
431+
" 519 503 22578 1.0";
430432
std::string theta
431433
= "theta 0.26 6.1e-03 0.12 0.12 0.080 0.25 0.47 "
432-
" 362 396 1.0";
434+
" 362 396 15718 1.0";
433435
std::string accept_stat
434436
= "accept_stat__ 0.90 4.6e-03 1.5e-01 0.064 0.57 0.96 1.0 "
435-
"1284 941 1.00";
437+
"1284 941 55805 1.00";
436438
std::string energy
437439
= "energy__ 7.8 5.1e-02 1.0e+00 0.75 6.8 7.5 9.9 "
438-
" 490 486 1.0";
440+
" 490 486 21299 1.0";
439441

440442
std::string path_separator;
441443
path_separator.push_back(get_path_separator());
@@ -480,16 +482,17 @@ TEST(CommandStansummary, check_console_output) {
480482

481483
TEST(CommandStansummary, check_csv_output) {
482484
std::string csv_header
483-
= "name,Mean,MCSE,StdDev,MAD,5%,50%,95%,ESS_bulk,ESS_tail,R_hat";
485+
= "name,Mean,MCSE,StdDev,MAD,5%,50%,95%,ESS_bulk,ESS_tail,ESS_bulk/"
486+
"s,R_hat";
484487
std::string lp
485488
= "\"lp__\",-7.2719,0.0365168,0.768874,0.303688,-8.98426,-6.97009,-6."
486-
"75007,519.29,503.309,1.00141";
489+
"75007,519.29,503.309,22577.8,1.00141";
487490
std::string energy
488491
= "\"energy__\",7.78428,0.0508815,1.0314,0.745859,6.80565,7.46758,9.8864,"
489-
"489.874,486.438,1.00495";
492+
"489.874,486.438,21298.9,1.00495";
490493
std::string theta
491494
= "\"theta\",0.256552,0.00610844,0.119654,0.120965,0.0802982,0.24996,0."
492-
"47034,361.506,395.736,1.00186";
495+
"47034,361.506,395.736,15717.6,1.00186";
493496

494497
std::string path_separator;
495498
path_separator.push_back(get_path_separator());
@@ -532,9 +535,11 @@ TEST(CommandStansummary, check_csv_output) {
532535
}
533536

534537
TEST(CommandStansummary, check_csv_output_no_percentiles) {
535-
std::string csv_header = "name,Mean,MCSE,StdDev,MAD,ESS_bulk,ESS_tail,R_hat";
538+
std::string csv_header
539+
= "name,Mean,MCSE,StdDev,MAD,ESS_bulk,ESS_tail,ESS_bulk/s,R_hat";
536540
std::string lp
537-
= "\"lp__\",-7.2719,0.0365168,0.768874,0.303688,519.29,503.309,1.00141";
541+
= "\"lp__\",-7.2719,0.0365168,0.768874,0.303688,519.29,503.309,22577.8,1."
542+
"00141";
538543

539544
std::string path_separator;
540545
path_separator.push_back(get_path_separator());
@@ -571,12 +576,15 @@ TEST(CommandStansummary, check_csv_output_no_percentiles) {
571576

572577
TEST(CommandStansummary, check_csv_output_sig_figs) {
573578
std::string csv_header
574-
= "name,Mean,MCSE,StdDev,MAD,5%,50%,95%,ESS_bulk,ESS_tail,R_hat";
575-
std::string lp = "\"lp__\",-7.3,0.037,0.77,0.3,-9,-7,-6.8,5.2e+02,5e+02,1";
579+
= "name,Mean,MCSE,StdDev,MAD,5%,50%,95%,ESS_bulk,ESS_tail,ESS_bulk/"
580+
"s,R_hat";
581+
std::string lp
582+
= "\"lp__\",-7.3,0.037,0.77,0.3,-9,-7,-6.8,5.2e+02,5e+02,2.3e+04,1";
576583
std::string energy
577-
= "\"energy__\",7.8,0.051,1,0.75,6.8,7.5,9.9,4.9e+02,4.9e+02,1";
584+
= "\"energy__\",7.8,0.051,1,0.75,6.8,7.5,9.9,4.9e+02,4.9e+02,2.1e+04,1";
578585
std::string theta
579-
= "\"theta\",0.26,0.0061,0.12,0.12,0.08,0.25,0.47,3.6e+02,4e+02,1";
586+
= "\"theta\",0.26,0.0061,0.12,0.12,0.08,0.25,0.47,3.6e+02,4e+02,1.6e+04,"
587+
"1";
580588

581589
std::string path_separator;
582590
path_separator.push_back(get_path_separator());
@@ -621,20 +629,21 @@ TEST(CommandStansummary, check_csv_output_sig_figs) {
621629

622630
TEST(CommandStansummary, check_csv_output_include_param) {
623631
std::string csv_header
624-
= "name,Mean,MCSE,StdDev,MAD,5%,50%,95%,ESS_bulk,ESS_tail,R_hat";
632+
= "name,Mean,MCSE,StdDev,MAD,5%,50%,95%,ESS_bulk,ESS_tail,ESS_bulk/"
633+
"s,R_hat";
625634
std::string lp
626635
= "\"lp__\",-15.5617,0.97319,6.05585,6.3817,-25.3182,-15.7598,-5.47732,"
627-
"41.1897,113.537,1.00153";
636+
"41.1897,113.537,396.283,1.00153";
628637
std::string energy
629638
= "\"energy__\",20.5888,1.01449,6.43127,6.6161,10.2809,20.8278,30.9921,"
630-
"42.5605,140.171,1.00069";
639+
"42.5605,140.171,409.472,1.00069";
631640
// note: skipping theta 1-5
632641
std::string theta6
633642
= "\"theta[6]\",5.001,0.365016,5.76072,5.37947,-4.95375,5.22746,14.1688,"
634-
"230.645,464.978,1.00054";
643+
"230.645,464.978,2219.02,1.00054";
635644
std::string theta7
636645
= "\"theta[7]\",8.54125,0.650098,6.22195,5.35785,-0.814388,8.09342,19."
637-
"2622,92.3075,241.177,1.00244";
646+
"2622,92.3075,241.177,888.084,1.00244";
638647
// note: skipping theta 8
639648
std::string message = "# Inference for Stan model: eight_schools_cp_model";
640649

0 commit comments

Comments
 (0)