
Commit 552737c (parent: 8ee3d96)

fix laplace_marginal optimization

2 files changed: 109 additions, 10 deletions

File: stan/math/mix/functor/laplace_marginal_density_estimator.hpp (83 additions, 8 deletions)
@@ -380,6 +380,31 @@ struct NewtonState {
     wolfe_status.num_backtracks_ = -1;  // Safe initial value for BB step
   }
 
+  /**
+   * @brief Constructs Newton state with a consistent (a_init, theta_init) pair.
+   *
+   * When the caller supplies a non-zero theta_init, a_init = Sigma^{-1} *
+   * theta_init must be provided to maintain the invariant theta = Sigma * a.
+   *
+   * @param theta_size Dimension of the latent space
+   * @param obj_fun Objective function: (a, theta) -> double
+   * @param theta_grad_f Gradient function: theta -> grad
+   * @param a_init Initial a value consistent with theta_init
+   * @param theta_init Initial theta value
+   */
+  template <typename ObjFun, typename ThetaGradFun, typename ThetaInitializer>
+  NewtonState(int theta_size, ObjFun&& obj_fun, ThetaGradFun&& theta_grad_f,
+              const Eigen::VectorXd& a_init,
+              ThetaInitializer&& theta_init)
+      : wolfe_info(std::forward<ObjFun>(obj_fun), a_init,
+                   std::forward<ThetaInitializer>(theta_init),
+                   std::forward<ThetaGradFun>(theta_grad_f), 0),
+        b(theta_size),
+        B(theta_size, theta_size),
+        prev_g(theta_size) {
+    wolfe_status.num_backtracks_ = -1;  // Safe initial value for BB step
+  }
+
   /**
    * @brief Access the current step state (mutable).
    * @return Reference to current WolfeStep
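The new constructor requires callers to hand it an a_init that already satisfies theta = Sigma * a. A minimal Eigen-only sketch of what that entails (Sigma and theta_init below are made-up values; the real caller is laplace_marginal_density_est later in this diff):

// Sketch: recovering the a_init the new constructor expects from a
// user-supplied theta_init.
#include <Eigen/Dense>
#include <iostream>

int main() {
  Eigen::MatrixXd Sigma(2, 2);
  Sigma << 2.0, 0.5,
           0.5, 1.0;                // stand-in SPD covariance
  Eigen::VectorXd theta_init(2);
  theta_init << 1.0, -1.0;          // non-zero user initialization

  // a_init = Sigma^{-1} * theta_init via a Cholesky solve, as in the diff.
  Eigen::VectorXd a_init = Sigma.llt().solve(theta_init);

  // The invariant the constructor documents: theta = Sigma * a.
  std::cout << (Sigma * a_init - theta_init).norm() << "\n";  // ~0
}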
@@ -426,9 +451,13 @@ inline void llt_with_jitter(LLT& llt_B, B_t& B, double min_jitter = 1e-10,
                             double max_jitter = 1e-5) {
   llt_B.compute(B);
   if (llt_B.info() != Eigen::Success) {
+    double prev_jitter = 0.0;
     double jitter_try = min_jitter;
     for (; jitter_try < max_jitter; jitter_try *= 10) {
-      B.diagonal().array() += jitter_try;
+      // Remove previously added jitter before adding the new (larger) amount,
+      // so that the total diagonal perturbation is exactly jitter_try.
+      B.diagonal().array() += (jitter_try - prev_jitter);
+      prev_jitter = jitter_try;
       llt_B.compute(B);
       if (llt_B.info() == Eigen::Success) {
         break;
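To see the fix in isolation: before this change each pass of the loop added the whole of jitter_try on top of the jitter already applied, so after trying 1e-10 and then 1e-9 the diagonal carried 1.1e-9 rather than 1e-9. A standalone sketch of the corrected escalation (a simplified analogue of llt_with_jitter, not the real function, which returns void):

// Simplified standalone analogue of llt_with_jitter after the fix.
#include <Eigen/Dense>

inline bool llt_with_jitter_sketch(Eigen::LLT<Eigen::MatrixXd>& llt_B,
                                   Eigen::MatrixXd& B,
                                   double min_jitter = 1e-10,
                                   double max_jitter = 1e-5) {
  llt_B.compute(B);
  if (llt_B.info() == Eigen::Success) {
    return true;
  }
  double prev_jitter = 0.0;
  for (double jitter_try = min_jitter; jitter_try < max_jitter;
       jitter_try *= 10) {
    // Add only the increment, so the diagonal carries exactly jitter_try
    // in total rather than the running sum 1e-10 + 1e-9 + ...
    B.diagonal().array() += (jitter_try - prev_jitter);
    prev_jitter = jitter_try;
    llt_B.compute(B);
    if (llt_B.info() == Eigen::Success) {
      return true;
    }
  }
  return false;
}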
@@ -935,6 +964,9 @@ inline auto run_newton_loop(SolverPolicy& solver, NewtonStateT& state,
       scratch.alpha() = 1.0;
       update_fun(scratch, state.curr(), state.prev(), scratch.eval_,
                  state.wolfe_info.p_);
+      // Save the full Newton step objective before the Wolfe line search
+      // overwrites scratch with intermediate trial points.
+      const double full_newton_obj = scratch.eval_.obj();
       if (scratch.alpha() <= options.line_search.min_alpha) {
         state.wolfe_status.accept_ = false;
         finish_update = true;
@@ -953,15 +985,42 @@
         state.wolfe_status = internal::wolfe_line_search(
             state.wolfe_info, update_fun, options.line_search, msgs);
       }
+      // When the Wolfe line search rejects, don't immediately terminate.
+      // Instead, let the Newton loop try at least one more iteration.
+      // The original code compared the stale curr.obj() (which equalled
+      // prev.obj() after the swap in update_next_step) and would always
+      // terminate on ANY Wolfe rejection — even on the very first Newton
+      // step. Now we only declare search_failed if the full Newton step
+      // itself didn't improve the objective.
+      bool search_failed;
+      if (!state.wolfe_status.accept_) {
+        if (full_newton_obj > state.prev().obj()) {
+          // The full Newton step (evaluated before Wolfe ran) improved
+          // the objective. Re-evaluate scratch at the full Newton step
+          // so we can accept it as the current iterate.
+          scratch.eval_.alpha() = 1.0;
+          update_fun(scratch, state.curr(), state.prev(), scratch.eval_,
+                     state.wolfe_info.p_);
+          state.curr().update(scratch);
+          state.wolfe_status.accept_ = true;
+          search_failed = false;
+        } else {
+          search_failed = true;
+        }
+      } else {
+        search_failed = false;
+      }
       /**
-       * Stop when objective change is small, or when a rejected Wolfe step
-       * fails to improve; finish_update then exits the Newton loop.
+       * Stop when objective change is small (absolute AND relative), or when
+       * a rejected Wolfe step fails to improve; finish_update then exits the
+       * Newton loop.
        */
+      double obj_change = std::abs(state.curr().obj() - state.prev().obj());
       bool objective_converged
-          = std::abs(state.curr().obj() - state.prev().obj())
-            < options.tolerance;
-      bool search_failed = (!state.wolfe_status.accept_
-                            && state.curr().obj() <= state.prev().obj());
+          = obj_change < options.tolerance
+            && obj_change
+                   < options.tolerance
+                         * std::abs(state.prev().obj());
       finish_update = objective_converged || search_failed;
     }
     if (finish_update) {
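A simplified scalar analogue of the new decision logic and the tightened convergence test (names and types below are illustrative stand-ins, not the real solver state; the objective is treated as maximized, matching the full_newton_obj > state.prev().obj() comparison above):

// Scalar sketch of the accept-or-fail decision and the combined test.
#include <cmath>

struct StepDecision {
  bool accept;         // take the full Newton step despite Wolfe rejection
  bool search_failed;  // neither Wolfe nor the full step improved
};

inline StepDecision resolve_wolfe_rejection(bool wolfe_accepted,
                                            double full_newton_obj,
                                            double prev_obj) {
  if (wolfe_accepted) {
    return {true, false};
  }
  if (full_newton_obj > prev_obj) {
    // Fall back to the full (alpha = 1) Newton step: it improved the
    // objective even though no point satisfied the Wolfe conditions.
    return {true, false};
  }
  return {false, true};
}

// The tightened convergence test: the change must be small both in
// absolute terms and relative to the previous objective.
inline bool objective_converged(double curr_obj, double prev_obj,
                                double tolerance) {
  const double obj_change = std::abs(curr_obj - prev_obj);
  return obj_change < tolerance
         && obj_change < tolerance * std::abs(prev_obj);
}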
@@ -1152,7 +1211,23 @@ inline auto laplace_marginal_density_est(
     return laplace_likelihood::theta_grad(ll_fun, theta_val, ll_args, msgs);
   };
   decltype(auto) theta_init = theta_init_impl<InitTheta>(theta_size, options);
-  internal::NewtonState state(theta_size, obj_fun, theta_grad_f, theta_init);
+  // When the user supplies a non-zero theta_init, we must initialise a
+  // consistently so that the invariant theta = Sigma * a holds. Otherwise
+  // the prior term -0.5 * a'*theta vanishes (a=0 while theta!=0), inflating
+  // the initial objective and causing the Wolfe line search to reject the
+  // first Newton step.
+  auto make_state = [&](auto&& theta_0) {
+    if constexpr (InitTheta) {
+      Eigen::VectorXd a_init = covariance.llt().solve(
+          Eigen::VectorXd(theta_0));
+      return internal::NewtonState(theta_size, obj_fun, theta_grad_f,
+                                   a_init, theta_0);
+    } else {
+      return internal::NewtonState(theta_size, obj_fun, theta_grad_f,
+                                   theta_0);
+    }
+  };
+  auto state = make_state(theta_init);
   // Start with safe step size
   auto update_fun = create_update_fun(
       std::move(obj_fun), std::move(theta_grad_f), covariance, options);
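For intuition on the comment above: with the consistent a = Sigma^{-1} * theta, the prior term -0.5 * a' * theta equals -0.5 * theta' * Sigma^{-1} * theta, which is strictly negative for any non-zero theta because Sigma is SPD; with a = 0 it silently becomes 0, so the starting point looks better than it really is and the first Newton step can appear to be a non-improvement. A small numeric sketch (values made up for illustration):

// Why a = 0 with theta != 0 inflates the initial objective.
#include <Eigen/Dense>
#include <iostream>

int main() {
  Eigen::MatrixXd Sigma(2, 2);
  Sigma << 2.0, 0.5,
           0.5, 1.0;
  Eigen::VectorXd theta(2);
  theta << 1.0, -1.0;

  Eigen::VectorXd a = Sigma.llt().solve(theta);
  std::cout << "consistent prior term: " << -0.5 * a.dot(theta) << "\n";
  std::cout << "a = 0 prior term:      " << 0.0 << "\n";  // spuriously higher
}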

File: stan/math/mix/functor/wolfe_line_search.hpp (26 additions, 2 deletions)
@@ -512,6 +512,29 @@ struct WolfeInfo {
           "theta and likelihood arguments.");
     }
   }
+  /**
+   * Construct WolfeInfo with a consistent (a_init, theta_init) pair.
+   *
+   * When the caller supplies a non-zero theta_init, the corresponding
+   * a_init = Sigma^{-1} * theta_init must be provided so that the
+   * invariant theta = Sigma * a holds at initialization. This avoids
+   * an inflated initial objective (the prior term -0.5 * a'*theta would
+   * otherwise vanish when a is zero but theta is not).
+   */
+  template <typename ObjFun, typename Theta0, typename ThetaGradF>
+  WolfeInfo(ObjFun&& obj_fun, const Eigen::VectorXd& a_init, Theta0&& theta0,
+            ThetaGradF&& theta_grad_f, int /*tag*/)
+      : curr_(std::forward<ObjFun>(obj_fun), a_init,
+              std::forward<Theta0>(theta0),
+              std::forward<ThetaGradF>(theta_grad_f)),
+        prev_(curr_),
+        scratch_(a_init.size()) {
+    if (!std::isfinite(curr_.obj())) {
+      throw std::domain_error(
+          "laplace_marginal_density: log likelihood is not finite at initial "
+          "theta and likelihood arguments.");
+    }
+  }
   WolfeInfo(WolfeData&& curr, WolfeData&& prev)
       : curr_(std::move(curr)),
         prev_(std::move(prev)),
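The trailing int /*tag*/ parameter looks odd at first glance; presumably it exists to keep this overload from colliding with the existing forwarding-reference constructor during overload resolution. A minimal sketch of the idiom (types and signatures are illustrative only, not the real WolfeInfo):

// A throwaway trailing int distinguishes a new overload from a greedy
// forwarding-reference constructor that would otherwise win.
#include <iostream>
#include <vector>

struct Info {
  template <typename F, typename A, typename T0, typename G>
  Info(F&&, A&&, T0&&, G&&) {  // existing greedy forwarding constructor
    std::cout << "forwarding overload\n";
  }
  template <typename F, typename T0, typename G>
  Info(F&&, const std::vector<double>&, T0&&, G&&, int /*tag*/) {
    std::cout << "tagged overload\n";  // only five-argument calls land here
  }
};

int main() {
  std::vector<double> a{1.0, 2.0};
  auto f = [] {};
  Info plain(f, a, 3.0, f);      // binds to the forwarding constructor
  Info tagged(f, a, 3.0, f, 0);  // the trailing 0 selects the new overload
}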
@@ -902,9 +925,10 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
       } else {  // [3]
         high = mid;
       }
+    } else {
+      // [4]
+      high = mid;
     }
-    // [4]
-    high = mid;
   } else {
     // [5]
     high = mid;
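The bug here was structural: high = mid (case [4]) ran unconditionally after the inner if/else, so it could clobber whatever bracket update an earlier case had made; the fix scopes it to an else branch so exactly one case fires per iteration. A generic, self-contained sketch of the pattern (the real conditions are Wolfe tests we only see in part; the toy root-finder below just illustrates the control flow):

// Bracket narrowing done right: each iteration updates exactly one end.
inline double find_root_sketch(double low, double high) {
  auto f = [](double x) { return x * x - 2.0; };  // toy objective
  while (high - low > 1e-12) {
    const double mid = 0.5 * (low + high);
    if (f(mid) < 0.0) {
      low = mid;  // a case like [1]/[2]: move the lower bracket end
    } else {
      // A [4]-style case: narrow from above, but ONLY when the first
      // branch did not fire. An unconditional high = mid after this
      // if/else would overwrite the low = mid update every iteration.
      high = mid;
    }
  }
  return 0.5 * (low + high);  // ~= sqrt(2)
}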
