Skip to content

Commit f0790df

Browse files
authored
Merge pull request #1317 from stan-dev/cleanup-log-prob-method
Stop passing raw stream to log_prob_grad helper
2 parents 45b0e9d + cdd4177 commit f0790df

2 files changed

Lines changed: 11 additions & 18 deletions

File tree

src/cmdstan/command.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -449,8 +449,7 @@ int command(int argc, const char *argv[]) {
449449
}
450450
}
451451
try {
452-
services_log_prob_grad(model, jacobian, params_r_ind, sig_figs,
453-
sample_writers[0].get_stream());
452+
services_log_prob_grad(model, jacobian, params_r_ind, sample_writers[0]);
454453
return_code = return_codes::OK;
455454
} catch (const std::exception &e) {
456455
msg << "Error during log_prob calculation:" << std::endl;

src/cmdstan/command_helper.hpp

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
#include <cmdstan/arguments/argument_parser.hpp>
55
#include <cmdstan/arguments/arg_sample.hpp>
66
#include <cmdstan/file.hpp>
7-
#include <stan/callbacks/unique_stream_writer.hpp>
87
#include <stan/callbacks/json_writer.hpp>
98
#include <stan/callbacks/writer.hpp>
109
#include <stan/io/dump.hpp>
@@ -580,18 +579,14 @@ std::vector<std::vector<double>> get_uparams_r(
580579
*/
581580
void services_log_prob_grad(const stan::model::model_base &model, bool jacobian,
582581
std::vector<std::vector<double>> &params_set,
583-
int sig_figs, std::ostream &output_stream) {
584-
// header row
585-
output_stream << std::setprecision(sig_figs) << "lp__,";
586-
std::vector<std::string> p_names;
582+
stan::callbacks::writer &output) {
583+
// header
584+
std::vector<std::string> p_names{"lp__"};
587585
model.unconstrained_param_names(p_names, false, false);
588-
for (size_t i = 0; i < p_names.size(); ++i) {
589-
output_stream << "g_" << p_names[i];
590-
if (i == p_names.size() - 1)
591-
output_stream << "\n";
592-
else
593-
output_stream << ",";
594-
}
586+
std::transform(p_names.begin() + 1, p_names.end(), p_names.begin() + 1,
587+
[](std::string s) { return "g_" + s; });
588+
output(p_names);
589+
595590
// data row(s)
596591
std::vector<int> dummy_params_i;
597592
double lp;
@@ -604,10 +599,9 @@ void services_log_prob_grad(const stan::model::model_base &model, bool jacobian,
604599
lp = stan::model::log_prob_grad<true, false>(model, params,
605600
dummy_params_i, gradients);
606601
}
607-
output_stream << lp << ",";
608-
std::copy(gradients.begin(), gradients.end() - 1,
609-
std::ostream_iterator<double>(output_stream, ","));
610-
output_stream << gradients.back() << "\n";
602+
// unfortunate: var.grad clears the vector, so need to insert lp afterwards
603+
gradients.insert(gradients.begin(), lp);
604+
output(gradients);
611605
}
612606
}
613607

0 commit comments

Comments
 (0)