Skip to content

Commit 34169a6

Browse files
committed
update solve_newton argument ordering.
1 parent 4c4d936 commit 34169a6

2 files changed

Lines changed: 52 additions & 50 deletions

File tree

stan/math/rev/functor/solve_newton.hpp

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@ namespace math {
5656
template <typename F, typename T, typename... Args,
5757
require_eigen_vector_t<T>* = nullptr,
5858
require_all_st_arithmetic<Args...>* = nullptr>
59-
Eigen::VectorXd solve_newton_impl(const F& f, const T& x,
60-
std::ostream* const msgs,
59+
Eigen::VectorXd solve_newton_tol(const F& f, const T& x,
6160
const double scaling_step_size,
6261
const double function_tolerance,
6362
const int64_t max_num_steps,
63+
std::ostream* const msgs,
6464
const Args&... args) {
6565
const auto& x_ref = to_ref(value_of(x));
6666

@@ -135,10 +135,10 @@ Eigen::VectorXd solve_newton_impl(const F& f, const T& x,
135135
template <typename F, typename T, typename... T_Args,
136136
require_eigen_vector_t<T>* = nullptr,
137137
require_any_st_var<T_Args...>* = nullptr>
138-
Eigen::Matrix<var, Eigen::Dynamic, 1> solve_newton_impl(
139-
const F& f, const T& x, std::ostream* const msgs,
140-
const double scaling_step_size, const double function_tolerance,
141-
const int64_t max_num_steps, const T_Args&... args) {
138+
Eigen::Matrix<var, Eigen::Dynamic, 1> solve_newton_tol(
139+
const F& f, const T& x, const double scaling_step_size,
140+
const double function_tolerance, const int64_t max_num_steps,
141+
std::ostream* const msgs, const T_Args&... args) {
142142
const auto& x_ref = to_ref(value_of(x));
143143
auto arena_args_tuple = make_chainable_ptr(std::make_tuple(eval(args)...));
144144
auto args_vals_tuple = math::apply(
@@ -234,21 +234,21 @@ Eigen::Matrix<var, Eigen::Dynamic, 1> solve_newton_impl(
234234
* @throw <code>std::invalid_argument</code> if max_num_steps is not positive.
235235
* @throw <code>std::domain_error if solver exceeds max_num_steps.
236236
*/
237-
template <typename F, typename T, typename... T_Args,
238-
require_eigen_vector_t<T>* = nullptr>
239-
Eigen::Matrix<stan::return_type_t<T_Args...>, Eigen::Dynamic, 1>
240-
solve_newton_tol(const F& f, const T& x, const double scaling_step_size,
241-
const double function_tolerance, const int64_t max_num_steps,
242-
std::ostream* const msgs, const T_Args&... args) {
243-
const auto& args_ref_tuple = std::make_tuple(to_ref(args)...);
244-
return math::apply(
245-
[&](const auto&... args_refs) {
246-
return solve_newton_impl(f, x, msgs,
247-
scaling_step_size, function_tolerance,
248-
max_num_steps, args_refs...);
249-
},
250-
args_ref_tuple);
251-
}
237+
// template <typename F, typename T, typename... T_Args,
238+
// require_eigen_vector_t<T>* = nullptr>
239+
// Eigen::Matrix<stan::return_type_t<T_Args...>, Eigen::Dynamic, 1>
240+
// solve_newton_tol(const F& f, const T& x, const double scaling_step_size,
241+
// const double function_tolerance, const int64_t max_num_steps,
242+
// std::ostream* const msgs, const T_Args&... args) {
243+
// const auto& args_ref_tuple = std::make_tuple(to_ref(args)...);
244+
// return math::apply(
245+
// [&](const auto&... args_refs) {
246+
// return solve_newton_impl(f, x, msgs,
247+
// scaling_step_size, function_tolerance,
248+
// max_num_steps, args_refs...);
249+
// },
250+
// args_ref_tuple);
251+
// }
252252

253253
/**
254254
* Return the solution to the specified system of algebraic
@@ -348,9 +348,9 @@ Eigen::Matrix<scalar_type_t<T2>, Eigen::Dynamic, 1> algebra_solver_newton(
348348
const double scaling_step_size = 1e-3,
349349
const double function_tolerance = 1e-6,
350350
const long int max_num_steps = 200) { // NOLINT(runtime/int)
351-
return solve_newton_impl(algebra_solver_adapter<F>(f), x, msgs,
351+
return solve_newton_tol(algebra_solver_adapter<F>(f), x,
352352
scaling_step_size, function_tolerance, max_num_steps,
353-
y, dat, dat_int);
353+
msgs, y, dat, dat_int);
354354
}
355355

356356
} // namespace math

test/unit/math/rev/functor/solve_newton_test.cpp

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@
1313
// Tests for newton solver.
1414

1515
TEST_F(algebra_solver_simple_eq_test, newton_dbl) {
16-
int solver_type = 1;
16+
bool is_newton = true;
1717
Eigen::VectorXd theta
18-
= simple_eq_test(simple_eq_functor(), y_dbl, solver_type);
18+
= simple_eq_test(simple_eq_functor(), y_dbl, is_newton);
1919
}
2020

2121
TEST_F(algebra_solver_simple_eq_test, newton_tuned_dbl) {
22-
int solver_type = 1;
22+
bool is_newton = true;
2323
Eigen::VectorXd theta
24-
= simple_eq_test(simple_eq_functor(), y_dbl, solver_type, true,
24+
= simple_eq_test(simple_eq_functor(), y_dbl, is_newton, true,
2525
scale_step, xtol, ftol, maxfev);
2626
}
2727

@@ -34,22 +34,22 @@ TEST_F(algebra_solver_simple_eq_nopara_test, newton_dbl) {
3434
}
3535

3636
TEST_F(algebra_solver_non_linear_eq_test, newton_dbl) {
37-
int solver_type = 1;
37+
bool is_newton = true;
3838
Eigen::VectorXd theta
39-
= non_linear_eq_test(non_linear_eq_functor(), y_dbl, solver_type);
39+
= non_linear_eq_test(non_linear_eq_functor(), y_dbl, is_newton);
4040
EXPECT_FLOAT_EQ(-y_dbl(0), theta(0));
4141
EXPECT_FLOAT_EQ(-y_dbl(1), theta(1));
4242
EXPECT_FLOAT_EQ(y_dbl(2), theta(2));
4343
}
4444

4545
TEST_F(error_message_test, newton_dbl) {
46-
int solver_type = 1;
47-
error_conditions_test(non_linear_eq_functor(), y_3, solver_type);
46+
bool is_newton = true;
47+
error_conditions_test(non_linear_eq_functor(), y_3, is_newton);
4848
}
4949

5050
TEST_F(max_steps_test, newton_dbl) {
51-
int solver_type = 1;
52-
max_num_steps_test(y, solver_type);
51+
bool is_newton = true;
52+
max_num_steps_test(y, is_newton);
5353
}
5454

5555
TEST(MathMatrixRevMat, unsolvable_flag_newton_dbl) {
@@ -112,12 +112,12 @@ TEST_F(degenerate_eq_test, newton_guess_saddle_point_dbl) {
112112

113113
TEST_F(algebra_solver_simple_eq_test, newton) {
114114
using stan::math::var;
115-
int solver_type = 1;
115+
bool is_newton = true;
116116
for (int k = 0; k < n_x; k++) {
117117
Eigen::Matrix<var, Eigen::Dynamic, 1> y = y_dbl;
118118

119119
Eigen::Matrix<var, Eigen::Dynamic, 1> theta
120-
= simple_eq_test(simple_eq_functor(), y, solver_type);
120+
= simple_eq_test(simple_eq_functor(), y, is_newton);
121121

122122
std::vector<stan::math::var> y_vec{y(0), y(1), y(2)};
123123
std::vector<double> g;
@@ -130,12 +130,12 @@ TEST_F(algebra_solver_simple_eq_test, newton) {
130130

131131
TEST_F(algebra_solver_simple_eq_test, newton_tuned) {
132132
using stan::math::var;
133-
int solver_type = 1;
133+
bool is_newton = true;
134134
for (int k = 0; k < n_x; k++) {
135135
Eigen::Matrix<var, Eigen::Dynamic, 1> y = y_dbl;
136136

137137
Eigen::Matrix<var, Eigen::Dynamic, 1> theta
138-
= simple_eq_test(simple_eq_functor(), y, solver_type, true, scale_step,
138+
= simple_eq_test(simple_eq_functor(), y, is_newton, true, scale_step,
139139
xtol, ftol, maxfev);
140140

141141
std::vector<stan::math::var> y_vec{y(0), y(1), y(2)};
@@ -157,11 +157,11 @@ TEST_F(algebra_solver_simple_eq_test, newton_init_is_para) {
157157

158158
TEST_F(algebra_solver_non_linear_eq_test, newton) {
159159
using stan::math::var;
160-
int solver_type = 1;
160+
bool is_newton = true;
161161
for (int k = 0; k < n_x; k++) {
162162
Eigen::Matrix<var, Eigen::Dynamic, 1> y = y_dbl;
163163
Eigen::Matrix<var, Eigen::Dynamic, 1> theta
164-
= non_linear_eq_test(non_linear_eq_functor(), y, solver_type);
164+
= non_linear_eq_test(non_linear_eq_functor(), y, is_newton);
165165

166166
EXPECT_FLOAT_EQ(-y(0).val(), theta(0).val());
167167
EXPECT_FLOAT_EQ(-y(1).val(), theta(1).val());
@@ -178,14 +178,14 @@ TEST_F(algebra_solver_non_linear_eq_test, newton) {
178178

179179
TEST_F(error_message_test, newton) {
180180
using stan::math::var;
181-
int solver_type = 1;
181+
bool is_newton = true;
182182
Eigen::Matrix<var, Eigen::Dynamic, 1> y = y_2;
183-
error_conditions_test(non_linear_eq_functor(), y, solver_type);
183+
error_conditions_test(non_linear_eq_functor(), y, is_newton);
184184
}
185185

186186
TEST_F(max_steps_test, newton) {
187-
int solver_type = 1;
188-
max_num_steps_test(y_var, solver_type);
187+
bool is_newton = true;
188+
max_num_steps_test(y_var, is_newton);
189189
}
190190

191191
TEST(MathMatrixRevMat, unsolvable_flag_newton) {
@@ -238,7 +238,7 @@ TEST_F(degenerate_eq_test, newton_guess2) {
238238

239239
TEST_F(variadic_test, newton) {
240240
using stan::math::var;
241-
int solver_type = 1;
241+
bool is_newton = true;
242242
bool is_impl = false;
243243
bool use_tol = false;
244244
for (int k = 0; k < n_x; k++) {
@@ -247,7 +247,7 @@ TEST_F(variadic_test, newton) {
247247
var y_3 = y_3_dbl;
248248

249249
Eigen::Matrix<var, Eigen::Dynamic, 1> theta = variadic_eq_impl_test(
250-
A, y_1, y_2, y_3, i, solver_type, is_impl, use_tol, scaling_step_size,
250+
A, y_1, y_2, y_3, i, is_newton, use_tol, scaling_step_size,
251251
relative_tolerance, function_tolerance, max_num_steps);
252252
std::vector<var> y_vec{y_1, y_2, y_3};
253253
std::vector<double> g;
@@ -261,12 +261,13 @@ TEST_F(variadic_test, newton) {
261261
// Additional tests for deprecated signature (with and without tol)
262262
TEST_F(algebra_solver_simple_eq_test, newton_deprecated) {
263263
using stan::math::var;
264-
int solver_type = 3;
264+
bool is_newton = true;
265265
for (int k = 0; k < n_x; k++) {
266266
Eigen::Matrix<var, Eigen::Dynamic, 1> y = y_dbl;
267267

268268
Eigen::Matrix<var, Eigen::Dynamic, 1> theta
269-
= simple_eq_test(simple_eq_functor(), y, solver_type);
269+
= simple_eq_non_varia_test(simple_eq_non_varia_functor(), y,
270+
is_newton);
270271

271272
std::vector<stan::math::var> y_vec{y(0), y(1), y(2)};
272273
std::vector<double> g;
@@ -279,13 +280,14 @@ TEST_F(algebra_solver_simple_eq_test, newton_deprecated) {
279280

280281
TEST_F(algebra_solver_simple_eq_test, newton_tuned_deprecated) {
281282
using stan::math::var;
282-
int solver_type = 3;
283+
bool is_newton = true;
283284
for (int k = 0; k < n_x; k++) {
284285
Eigen::Matrix<var, Eigen::Dynamic, 1> y = y_dbl;
285286

286287
Eigen::Matrix<var, Eigen::Dynamic, 1> theta
287-
= simple_eq_test(simple_eq_functor(), y, solver_type, true, scale_step,
288-
xtol, ftol, maxfev);
288+
= simple_eq_non_varia_test(simple_eq_non_varia_functor(), y,
289+
is_newton, true,
290+
scale_step, xtol, ftol, maxfev);
289291

290292
std::vector<stan::math::var> y_vec{y(0), y(1), y(2)};
291293
std::vector<double> g;

0 commit comments

Comments
 (0)