Skip to content

Commit 4d48730

Browse files
authored
Merge pull request #2820 from stan-dev/feature/issue-2828-algebra_solver_variadic
Feature/issue 2828 algebra solver variadic
2 parents 39a2855 + a88e494 commit 4d48730

7 files changed

Lines changed: 434 additions & 254 deletions

File tree

make/tests

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ $(IDAS_TESTS) : $(LIBSUNDIALS)
8282
# KINSOL tests
8383
##
8484

85-
ALGEBRA_SOLVER_TESTS := $(subst .cpp,$(EXE),$(call findfiles,test,*algebra_solver*_test.cpp))
85+
ALGEBRA_SOLVER_TESTS := $(subst .cpp,$(EXE),$(call findfiles,test,*solve*_test.cpp))
8686
$(ALGEBRA_SOLVER_TESTS) : $(LIBSUNDIALS)
8787

8888
### These can be generated by the jumbo tests and include the above kinds

stan/math/rev/functor.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
#define STAN_MATH_REV_FUNCTOR_HPP
33

44
#include <stan/math/rev/functor/algebra_solver_fp.hpp>
5-
#include <stan/math/rev/functor/algebra_solver_powell.hpp>
6-
#include <stan/math/rev/functor/algebra_solver_newton.hpp>
5+
#include <stan/math/rev/functor/solve_powell.hpp>
6+
#include <stan/math/rev/functor/solve_newton.hpp>
77
#include <stan/math/rev/functor/algebra_system.hpp>
88
#include <stan/math/rev/functor/apply_scalar_unary.hpp>
99
#include <stan/math/rev/functor/apply_scalar_binary.hpp>

stan/math/rev/functor/algebra_solver_newton.hpp renamed to stan/math/rev/functor/solve_newton.hpp

Lines changed: 75 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
#ifndef STAN_MATH_REV_FUNCTOR_ALGEBRA_SOLVER_NEWTON_HPP
2-
#define STAN_MATH_REV_FUNCTOR_ALGEBRA_SOLVER_NEWTON_HPP
1+
#ifndef STAN_MATH_REV_FUNCTOR_SOLVE_NEWTON_HPP
2+
#define STAN_MATH_REV_FUNCTOR_SOLVE_NEWTON_HPP
33

44
#include <stan/math/rev/core.hpp>
55
#include <stan/math/rev/functor/algebra_system.hpp>
@@ -56,21 +56,19 @@ namespace math {
5656
template <typename F, typename T, typename... Args,
5757
require_eigen_vector_t<T>* = nullptr,
5858
require_all_st_arithmetic<Args...>* = nullptr>
59-
Eigen::VectorXd algebra_solver_newton_impl(const F& f, const T& x,
60-
std::ostream* const msgs,
61-
const double scaling_step_size,
62-
const double function_tolerance,
63-
const int64_t max_num_steps,
64-
const Args&... args) {
59+
Eigen::VectorXd solve_newton_tol(const F& f, const T& x,
60+
const double scaling_step_size,
61+
const double function_tolerance,
62+
const int64_t max_num_steps,
63+
std::ostream* const msgs,
64+
const Args&... args) {
6565
const auto& x_ref = to_ref(value_of(x));
6666

67-
check_nonzero_size("algebra_solver_newton", "initial guess", x_ref);
68-
check_finite("algebra_solver_newton", "initial guess", x_ref);
69-
check_nonnegative("algebra_solver_newton", "scaling_step_size",
70-
scaling_step_size);
71-
check_nonnegative("algebra_solver_newton", "function_tolerance",
72-
function_tolerance);
73-
check_positive("algebra_solver_newton", "max_num_steps", max_num_steps);
67+
check_nonzero_size("solve_newton", "initial guess", x_ref);
68+
check_finite("solve_newton", "initial guess", x_ref);
69+
check_nonnegative("solve_newton", "scaling_step_size", scaling_step_size);
70+
check_nonnegative("solve_newton", "function_tolerance", function_tolerance);
71+
check_positive("solve_newton", "max_num_steps", max_num_steps);
7472

7573
return kinsol_solve(f, x_ref, scaling_step_size, function_tolerance,
7674
max_num_steps, 1, 10, KIN_LINESEARCH, msgs, args...);
@@ -137,10 +135,10 @@ Eigen::VectorXd algebra_solver_newton_impl(const F& f, const T& x,
137135
template <typename F, typename T, typename... T_Args,
138136
require_eigen_vector_t<T>* = nullptr,
139137
require_any_st_var<T_Args...>* = nullptr>
140-
Eigen::Matrix<var, Eigen::Dynamic, 1> algebra_solver_newton_impl(
141-
const F& f, const T& x, std::ostream* const msgs,
142-
const double scaling_step_size, const double function_tolerance,
143-
const int64_t max_num_steps, const T_Args&... args) {
138+
Eigen::Matrix<var, Eigen::Dynamic, 1> solve_newton_tol(
139+
const F& f, const T& x, const double scaling_step_size,
140+
const double function_tolerance, const int64_t max_num_steps,
141+
std::ostream* const msgs, const T_Args&... args) {
144142
const auto& x_ref = to_ref(value_of(x));
145143
auto arena_args_tuple = make_chainable_ptr(std::make_tuple(eval(args)...));
146144
auto args_vals_tuple = math::apply(
@@ -149,13 +147,11 @@ Eigen::Matrix<var, Eigen::Dynamic, 1> algebra_solver_newton_impl(
149147
},
150148
*arena_args_tuple);
151149

152-
check_nonzero_size("algebra_solver_newton", "initial guess", x_ref);
153-
check_finite("algebra_solver_newton", "initial guess", x_ref);
154-
check_nonnegative("algebra_solver_newton", "scaling_step_size",
155-
scaling_step_size);
156-
check_nonnegative("algebra_solver_newton", "function_tolerance",
157-
function_tolerance);
158-
check_positive("algebra_solver_newton", "max_num_steps", max_num_steps);
150+
check_nonzero_size("solve_newton", "initial guess", x_ref);
151+
check_finite("solve_newton", "initial guess", x_ref);
152+
check_nonnegative("solve_newton", "scaling_step_size", scaling_step_size);
153+
check_nonnegative("solve_newton", "function_tolerance", function_tolerance);
154+
check_positive("solve_newton", "max_num_steps", max_num_steps);
159155

160156
// Solve the system
161157
Eigen::VectorXd theta_dbl = math::apply(
@@ -204,6 +200,56 @@ Eigen::Matrix<var, Eigen::Dynamic, 1> algebra_solver_newton_impl(
204200
return ret_type(ret);
205201
}
206202

203+
/**
204+
* Return the solution to the specified system of algebraic
205+
* equations given an initial guess, and parameters and data,
206+
* which get passed into the algebraic system. Use the
207+
* KINSOL solver from the SUNDIALS suite.
208+
*
209+
* This signature does not give users control over the tuning parameters
210+
* and instead relies on default values.
211+
*
212+
* @tparam F type of equation system function
213+
* @tparam T type of elements in the x vector
214+
* @tparam Args types of additional input to the equation system functor
215+
*
216+
* @param[in] f Functor that evaluates the system of equations.
217+
* @param[in] x Vector of starting values (initial guess).
218+
* @param[in, out] msgs The print stream for warning messages.
219+
* @param[in] scaling_step_size Scaled-step stopping tolerance. If
220+
* a Newton step is smaller than the scaling step
221+
* tolerance, the code breaks, assuming the solver is no
222+
* longer making significant progress (i.e. is stuck)
223+
* @param[in] function_tolerance determines whether roots are acceptable.
224+
* @param[in] max_num_steps maximum number of function evaluations.
225+
* @param[in, out] msgs the print stream for warning messages.
226+
* @param[in] args Additional parameters to the equation system functor.
227+
* @return theta Vector of solutions to the system of equations.
228+
* @throw <code>std::invalid_argument</code> if x has size zero.
229+
* @throw <code>std::invalid_argument</code> if x has non-finite elements.
230+
* @throw <code>std::invalid_argument</code> if scaled_step_size is strictly
231+
* negative.
232+
* @throw <code>std::invalid_argument</code> if function_tolerance is strictly
233+
* negative.
234+
* @throw <code>std::invalid_argument</code> if max_num_steps is not positive.
235+
* @throw <code>std::domain_error if solver exceeds max_num_steps.
236+
*/
237+
template <typename F, typename T, typename... T_Args,
238+
require_eigen_vector_t<T>* = nullptr>
239+
Eigen::Matrix<stan::return_type_t<T_Args...>, Eigen::Dynamic, 1> solve_newton(
240+
const F& f, const T& x, std::ostream* const msgs, const T_Args&... args) {
241+
double scaling_step_size = 1e-3;
242+
double function_tolerance = 1e-6;
243+
int64_t max_num_steps = 200;
244+
const auto& args_ref_tuple = std::make_tuple(to_ref(args)...);
245+
return math::apply(
246+
[&](const auto&... args_refs) {
247+
return solve_newton_tol(f, x, scaling_step_size, function_tolerance,
248+
max_num_steps, msgs, args_refs...);
249+
},
250+
args_ref_tuple);
251+
}
252+
207253
/**
208254
* Return the solution to the specified system of algebraic
209255
* equations given an initial guess, and parameters and data,
@@ -252,9 +298,9 @@ Eigen::Matrix<scalar_type_t<T2>, Eigen::Dynamic, 1> algebra_solver_newton(
252298
const double scaling_step_size = 1e-3,
253299
const double function_tolerance = 1e-6,
254300
const long int max_num_steps = 200) { // NOLINT(runtime/int)
255-
return algebra_solver_newton_impl(algebra_solver_adapter<F>(f), x, msgs,
256-
scaling_step_size, function_tolerance,
257-
max_num_steps, y, dat, dat_int);
301+
return solve_newton_tol(algebra_solver_adapter<F>(f), x, scaling_step_size,
302+
function_tolerance, max_num_steps, msgs, y, dat,
303+
dat_int);
258304
}
259305

260306
} // namespace math

stan/math/rev/functor/algebra_solver_powell.hpp renamed to stan/math/rev/functor/solve_powell.hpp

Lines changed: 54 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
#ifndef STAN_MATH_REV_FUNCTOR_ALGEBRA_SOLVER_POWELL_HPP
2-
#define STAN_MATH_REV_FUNCTOR_ALGEBRA_SOLVER_POWELL_HPP
1+
#ifndef STAN_MATH_REV_FUNCTOR_SOLVE_POWELL_HPP
2+
#define STAN_MATH_REV_FUNCTOR_SOLVE_POWELL_HPP
33

44
#include <stan/math/rev/meta.hpp>
55
#include <stan/math/rev/core.hpp>
@@ -45,11 +45,10 @@ namespace math {
4545
*/
4646
template <typename F, typename T, typename... Args,
4747
require_eigen_vector_t<T>* = nullptr>
48-
T& algebra_solver_powell_call_solver(const F& f, T& x, std::ostream* const msgs,
49-
const double relative_tolerance,
50-
const double function_tolerance,
51-
const int64_t max_num_steps,
52-
const Args&... args) {
48+
T& solve_powell_call_solver(const F& f, T& x, std::ostream* const msgs,
49+
const double relative_tolerance,
50+
const double function_tolerance,
51+
const int64_t max_num_steps, const Args&... args) {
5352
// Construct the solver
5453
hybrj_functor_solver<F> hfs(f);
5554
Eigen::HybridNonLinearSolver<hybrj_functor_solver<F>> solver(hfs);
@@ -103,11 +102,11 @@ T& algebra_solver_powell_call_solver(const F& f, T& x, std::ostream* const msgs,
103102
*
104103
* @param[in] f Functor that evaluates the system of equations.
105104
* @param[in] x Vector of starting values (initial guess).
106-
* @param[in, out] msgs the print stream for warning messages.
107105
* @param[in] relative_tolerance determines the convergence criteria
108106
* for the solution.
109107
* @param[in] function_tolerance determines whether roots are acceptable.
110108
* @param[in] max_num_steps maximum number of function evaluations.
109+
* @param[in, out] msgs the print stream for warning messages.
111110
* @param[in] args additional parameters to the equation system functor.
112111
* @return theta Vector of solutions to the system of equations.
113112
* @pre f returns finite values when passed any value of x and the given args.
@@ -126,12 +125,12 @@ T& algebra_solver_powell_call_solver(const F& f, T& x, std::ostream* const msgs,
126125
template <typename F, typename T, typename... Args,
127126
require_eigen_vector_t<T>* = nullptr,
128127
require_all_st_arithmetic<Args...>* = nullptr>
129-
Eigen::VectorXd algebra_solver_powell_impl(const F& f, const T& x,
130-
std::ostream* const msgs,
131-
const double relative_tolerance,
132-
const double function_tolerance,
133-
const int64_t max_num_steps,
134-
const Args&... args) {
128+
Eigen::VectorXd solve_powell_tol(const F& f, const T& x,
129+
const double relative_tolerance,
130+
const double function_tolerance,
131+
const int64_t max_num_steps,
132+
std::ostream* const msgs,
133+
const Args&... args) {
135134
auto x_ref = eval(value_of(x));
136135
auto args_vals_tuple = std::make_tuple(to_ref(args)...);
137136

@@ -141,20 +140,18 @@ Eigen::VectorXd algebra_solver_powell_impl(const F& f, const T& x,
141140
args_vals_tuple);
142141
};
143142

144-
check_nonzero_size("algebra_solver_powell", "initial guess", x_ref);
145-
check_finite("algebra_solver_powell", "initial guess", x_ref);
143+
check_nonzero_size("solve_powell", "initial guess", x_ref);
144+
check_finite("solve_powell", "initial guess", x_ref);
146145
check_nonnegative("alegbra_solver_powell", "relative_tolerance",
147146
relative_tolerance);
148-
check_nonnegative("algebra_solver_powell", "function_tolerance",
149-
function_tolerance);
150-
check_positive("algebra_solver_powell", "max_num_steps", max_num_steps);
151-
check_matching_sizes("algebra_solver", "the algebraic system's output",
147+
check_nonnegative("solve_powell", "function_tolerance", function_tolerance);
148+
check_positive("solve_powell", "max_num_steps", max_num_steps);
149+
check_matching_sizes("solve_powell", "the algebraic system's output",
152150
f_wrt_x(x_ref), "the vector of unknowns, x,", x_ref);
153151

154152
// Solve the system
155-
return algebra_solver_powell_call_solver(f_wrt_x, x_ref, msgs,
156-
relative_tolerance,
157-
function_tolerance, max_num_steps);
153+
return solve_powell_call_solver(f_wrt_x, x_ref, msgs, relative_tolerance,
154+
function_tolerance, max_num_steps);
158155
}
159156

160157
/**
@@ -163,31 +160,20 @@ Eigen::VectorXd algebra_solver_powell_impl(const F& f, const T& x,
163160
* which get passed into the algebraic system.
164161
* Use Powell's dogleg solver.
165162
*
166-
* The user can also specify the relative tolerance
167-
* (xtol in Eigen's code), the function tolerance,
168-
* and the maximum number of steps (maxfev in Eigen's code).
163+
* This signature does not let the user specify the tuning parameters of the
164+
* solver (instead default values are used).
169165
*
170166
* @tparam F type of equation system function
171-
* @tparam T1 type of elements in the x vector
172-
* @tparam T2 type of elements in the y vector
167+
* @tparam T type of elements in the x vector
168+
* @tparam Args types of additional input to the equation system functor
173169
*
174170
* @param[in] f Functor that evaluates the system of equations.
175171
* @param[in] x Vector of starting values (initial guess).
176-
* @param[in] y parameter vector for the equation system.
177-
* @param[in] dat continuous data vector for the equation system.
178-
* @param[in] dat_int integer data vector for the equation system.
179172
* @param[in, out] msgs the print stream for warning messages.
180-
* @param[in] relative_tolerance determines the convergence criteria
181-
* for the solution.
182-
* @param[in] function_tolerance determines whether roots are acceptable.
183-
* @param[in] max_num_steps maximum number of function evaluations.
173+
* @param[in] args Additional parameters to the equation system functor.
184174
* @return theta Vector of solutions to the system of equations.
185175
* @throw <code>std::invalid_argument</code> if x has size zero.
186176
* @throw <code>std::invalid_argument</code> if x has non-finite elements.
187-
* @throw <code>std::invalid_argument</code> if y has non-finite elements.
188-
* @throw <code>std::invalid_argument</code> if dat has non-finite elements.
189-
* @throw <code>std::invalid_argument</code> if dat_int has non-finite
190-
* elements.
191177
* @throw <code>std::invalid_argument</code> if relative_tolerance is strictly
192178
* negative.
193179
* @throw <code>std::invalid_argument</code> if function_tolerance is strictly
@@ -197,17 +183,20 @@ Eigen::VectorXd algebra_solver_powell_impl(const F& f, const T& x,
197183
* @throw <code>std::domain_error</code> if the norm of the solution exceeds
198184
* the function tolerance.
199185
*/
200-
template <typename F, typename T1, typename T2,
201-
require_all_eigen_vector_t<T1, T2>* = nullptr>
202-
Eigen::Matrix<value_type_t<T2>, Eigen::Dynamic, 1> algebra_solver_powell(
203-
const F& f, const T1& x, const T2& y, const std::vector<double>& dat,
204-
const std::vector<int>& dat_int, std::ostream* const msgs = nullptr,
205-
const double relative_tolerance = 1e-10,
206-
const double function_tolerance = 1e-6,
207-
const int64_t max_num_steps = 1e+3) {
208-
return algebra_solver_powell_impl(algebra_solver_adapter<F>(f), x, msgs,
209-
relative_tolerance, function_tolerance,
210-
max_num_steps, y, dat, dat_int);
186+
template <typename F, typename T, typename... T_Args,
187+
require_eigen_vector_t<T>* = nullptr>
188+
Eigen::Matrix<stan::return_type_t<T_Args...>, Eigen::Dynamic, 1> solve_powell(
189+
const F& f, const T& x, std::ostream* const msgs, const T_Args&... args) {
190+
double relative_tolerance = 1e-10;
191+
double function_tolerance = 1e-6;
192+
int64_t max_num_steps = 200;
193+
const auto& args_ref_tuple = std::make_tuple(to_ref(args)...);
194+
return math::apply(
195+
[&](const auto&... args_refs) {
196+
return solve_powell_tol(f, x, relative_tolerance, function_tolerance,
197+
max_num_steps, msgs, args_refs...);
198+
},
199+
args_ref_tuple);
211200
}
212201

213202
/**
@@ -265,8 +254,9 @@ Eigen::Matrix<value_type_t<T2>, Eigen::Dynamic, 1> algebra_solver(
265254
const double relative_tolerance = 1e-10,
266255
const double function_tolerance = 1e-6,
267256
const int64_t max_num_steps = 1e+3) {
268-
return algebra_solver_powell(f, x, y, dat, dat_int, msgs, relative_tolerance,
269-
function_tolerance, max_num_steps);
257+
return solve_powell_tol(algebra_solver_adapter<F>(f), x, relative_tolerance,
258+
function_tolerance, max_num_steps, msgs, y, dat,
259+
dat_int);
270260
}
271261

272262
/**
@@ -309,11 +299,11 @@ Eigen::Matrix<value_type_t<T2>, Eigen::Dynamic, 1> algebra_solver(
309299
*
310300
* @param[in] f Functor that evaluates the system of equations.
311301
* @param[in] x Vector of starting values (initial guess).
312-
* @param[in, out] msgs the print stream for warning messages.
313302
* @param[in] relative_tolerance determines the convergence criteria
314303
* for the solution.
315304
* @param[in] function_tolerance determines whether roots are acceptable.
316305
* @param[in] max_num_steps maximum number of function evaluations.
306+
* @param[in, out] msgs the print stream for warning messages.
317307
* @param[in] args Additional parameters to the equation system functor.
318308
* @return theta Vector of solutions to the system of equations.
319309
* @pre f returns finite values when passed any value of x and the given args.
@@ -332,10 +322,10 @@ Eigen::Matrix<value_type_t<T2>, Eigen::Dynamic, 1> algebra_solver(
332322
template <typename F, typename T, typename... T_Args,
333323
require_eigen_vector_t<T>* = nullptr,
334324
require_any_st_var<T_Args...>* = nullptr>
335-
Eigen::Matrix<var, Eigen::Dynamic, 1> algebra_solver_powell_impl(
336-
const F& f, const T& x, std::ostream* const msgs,
337-
const double relative_tolerance, const double function_tolerance,
338-
const int64_t max_num_steps, const T_Args&... args) {
325+
Eigen::Matrix<var, Eigen::Dynamic, 1> solve_powell_tol(
326+
const F& f, const T& x, const double relative_tolerance,
327+
const double function_tolerance, const int64_t max_num_steps,
328+
std::ostream* const msgs, const T_Args&... args) {
339329
auto x_ref = eval(value_of(x));
340330
auto arena_args_tuple = make_chainable_ptr(std::make_tuple(eval(args)...));
341331
auto args_vals_tuple = math::apply(
@@ -350,19 +340,18 @@ Eigen::Matrix<var, Eigen::Dynamic, 1> algebra_solver_powell_impl(
350340
args_vals_tuple);
351341
};
352342

353-
check_nonzero_size("algebra_solver_powell", "initial guess", x_ref);
354-
check_finite("algebra_solver_powell", "initial guess", x_ref);
343+
check_nonzero_size("solve_powell", "initial guess", x_ref);
344+
check_finite("solve_powell", "initial guess", x_ref);
355345
check_nonnegative("alegbra_solver_powell", "relative_tolerance",
356346
relative_tolerance);
357-
check_nonnegative("algebra_solver_powell", "function_tolerance",
358-
function_tolerance);
359-
check_positive("algebra_solver_powell", "max_num_steps", max_num_steps);
360-
check_matching_sizes("algebra_solver", "the algebraic system's output",
347+
check_nonnegative("solve_powell", "function_tolerance", function_tolerance);
348+
check_positive("solve_powell", "max_num_steps", max_num_steps);
349+
check_matching_sizes("solve_powell", "the algebraic system's output",
361350
f_wrt_x(x_ref), "the vector of unknowns, x,", x_ref);
362351

363352
// Solve the system
364-
algebra_solver_powell_call_solver(f_wrt_x, x_ref, msgs, relative_tolerance,
365-
function_tolerance, max_num_steps);
353+
solve_powell_call_solver(f_wrt_x, x_ref, msgs, relative_tolerance,
354+
function_tolerance, max_num_steps);
366355

367356
Eigen::MatrixXd Jf_x;
368357
Eigen::VectorXd f_x;

0 commit comments

Comments
 (0)