Skip to content

Commit 39a96dc

Browse files
authored
Merge pull request #1586 from stan-dev/laplace-tol-updates
Update laplace _tol functions for allow_fallthrough argument, tuple of control params
2 parents 5bd454f + 1ac6abd commit 39a96dc

21 files changed

Lines changed: 536 additions & 292 deletions

src/frontend/Semantic_error.ml

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,16 @@ module TypeError = struct
7777

7878
let laplace_tolerance_arg_name n =
7979
match n with
80-
| 1 -> "first control parameter (initial guess)"
81-
| 2 -> "second control parameter (tolerance)"
82-
| 3 -> "third control parameter (max_num_steps)"
83-
| 4 -> "fourth control parameter (hessian_block_size)"
84-
| 5 -> "fifth control parameter (solver)"
85-
| 6 -> "sixth control parameter (max_steps_line_search)"
86-
| n -> Fmt.str "%a control parameter" (Fmt.ordinal ()) n
80+
| 1 -> "first element of the control parameter tuple (initial guess)"
81+
| 2 -> "second element of the control parameter tuple (tolerance)"
82+
| 3 -> "third element of the control parameter tuple (max_num_steps)"
83+
| 4 -> "fourth element of the control parameter tuple (hessian_block_size)"
84+
| 5 -> "fifth element of the control parameter tuple (solver)"
85+
| 6 ->
86+
"sixth element of the control parameter tuple (max_steps_line_search)"
87+
| 7 -> "seventh element of the control parameter tuple (allow_fallthrough)"
88+
| n ->
89+
Fmt.str "%a element of the control parameter tuple" (Fmt.ordinal ()) n
8790

8891
let rec expected_types : UnsizedType.t Common.Nonempty_list.t Fmt.t =
8992
let ust = expected_style UnsizedType.pp in
@@ -230,17 +233,23 @@ module TypeError = struct
230233
an embedded Laplace approximation."
231234
quoted banned_function
232235
| IlltypedLaplaceTooMany (name, n_args) ->
233-
Fmt.pf ppf
234-
"Received %d extra %a at the end of the call to %a.@ Did you mean to \
235-
call the _tol version?"
236+
Fmt.pf ppf "Received %d extra %a at the end of the call to %a.@ %s"
236237
n_args arguments n_args quoted name
237-
(* For tolerances, because these come at the end, we want to update their
238-
position number accordingly, which is why these reimplement some of the
239-
printing from [SignatureMismatch] *)
238+
(if String.is_substring ~substring:"_tol" name then
239+
"Only a single tuple of control parameters is expected."
240+
else if n_args = 1 then "Did you mean to call the _tol version?"
241+
else "Did you mean to call the _tol version with a tuple of these?")
242+
| IlltypedLaplaceTolArgs (name, ArgNumMismatch (_, 0)) ->
243+
Fmt.pf ppf
244+
"Missing control parameter tuple at the end of the call to %a.@ \
245+
Expected a tuple of %a arguments for the control parameters."
246+
quoted name (expected_style Fmt.int)
247+
(List.length Stan_math_signatures.laplace_tolerance_argument_types)
240248
| IlltypedLaplaceTolArgs (name, ArgNumMismatch (_, found)) ->
241249
Fmt.pf ppf
242-
"@[<v>Received %a control %a at the end of the call to %a.@ Expected \
243-
%a arguments for the control parameters instead.@]"
250+
"@[<v>Received a tuple of %a control %a at the end of the call to \
251+
%a.@ Expected tuple of %a arguments for the control parameters \
252+
instead.@]"
244253
(actual_style Fmt.int) found arguments found quoted name
245254
(expected_style Fmt.int)
246255
(List.length Stan_math_signatures.laplace_tolerance_argument_types)

src/frontend/Typechecker.ml

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -716,21 +716,33 @@ let check_function_callable_with_tuple cf tenv caller_id fname
716716
let verify_laplace_control_args loc id (args : typed_expression list) =
717717
match (String.is_substring ~substring:"_tol" id.name, args) with
718718
| false, [] -> ()
719-
| true, _ -> (
720-
let arg_tys = List.map ~f:arg_type args in
719+
| true, [arg] -> (
720+
let arg_tys =
721+
check_texpression_is_tuple arg
722+
("Control arguments for '" ^ id.name ^ "'") in
721723
match
722724
SignatureMismatch.check_compatible_arguments_mod_conv
723725
Stan_math_signatures.laplace_tolerance_argument_types arg_tys
724726
with
725727
| Ok _ -> ()
726728
| Error err ->
727729
let loc =
728-
let which_arg = match err with ArgError (i, _) -> i | _ -> 0 in
729-
List.nth args which_arg
730-
|> Option.value_map ~f:(fun expr -> expr.emeta.loc) ~default:loc
731-
in
730+
let which_arg = match err with ArgError (i, _) -> i - 1 | _ -> 0 in
731+
let elts = match arg.expr with TupleExpr elts -> elts | _ -> [] in
732+
List.nth elts which_arg
733+
|> Option.value_map ~f:(fun e -> e.emeta.loc) ~default:loc in
732734
Semantic_error.illtyped_laplace_tolerance_args loc id.name err
733735
|> error)
736+
| true, [] ->
737+
Semantic_error.illtyped_laplace_tolerance_args loc id.name
738+
(SignatureMismatch.check_compatible_arguments_mod_conv
739+
Stan_math_signatures.laplace_tolerance_argument_types []
740+
|> Result.error |> Option.value_exn)
741+
|> error
742+
| true, _ :: a :: _ ->
743+
Semantic_error.illtyped_laplace_extra_args a.emeta.loc id.name
744+
(List.length args - 1)
745+
|> error
734746
| false, a :: _ ->
735747
Semantic_error.illtyped_laplace_extra_args a.emeta.loc id.name
736748
(List.length args)

src/stan_math_signatures/Stan_math_signatures.ml

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -169,14 +169,7 @@ let embedded_laplace_functions =
169169
|> String.Set.of_list
170170

171171
let is_embedded_laplace_fn name =
172-
(* TEMPORARY: remove after https://github.com/stan-dev/math/pull/3250 or
173-
similar is merged *)
174-
(not
175-
(let version = "%%VERSION%%" in
176-
String.equal "v2.38.0" version
177-
|| String.equal "v2.38.0-rc"
178-
(String.sub ~pos:0 ~len:(String.length version - 1) version)))
179-
&& Set.mem embedded_laplace_functions (Utils.stdlib_distribution_name name)
172+
Set.mem embedded_laplace_functions (Utils.stdlib_distribution_name name)
180173

181174
let laplace_helper_lik_args =
182175
[ ( "bernoulli_logit"
@@ -206,7 +199,8 @@ let laplace_tolerance_argument_types =
206199
[ (AutoDiffable, UVector) (* theta_0 *); (DataOnly, UReal) (* tolerance *)
207200
; (DataOnly, UInt) (* max_num_steps *)
208201
; (DataOnly, UInt) (* hessian_block_size *); (DataOnly, UInt) (* solver *)
209-
; (DataOnly, UInt) (* max_steps_line_search *) ]
202+
; (DataOnly, UInt) (* max_steps_line_search *)
203+
; (DataOnly, UInt) (* allow_fallthrough *) ]
210204

211205
let is_special_function_name name =
212206
is_stan_math_variadic_function_name name

test/integration/bad/embedded_laplace/bad_theta0.stan

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ transformed data {
3030

3131
// control parameters for Laplace approximation
3232
real tolerance = 1e-6;
33-
int max_num_steps = 100, hessian_block_size = 1, solver = 1, max_steps_line_search = 0;
33+
int max_num_steps = 100, hessian_block_size = 1, solver = 1, max_steps_line_search = 0, allow_fallthrough = 1;
3434
}
3535
parameters {
3636
real<lower=0> alpha;
@@ -39,6 +39,7 @@ parameters {
3939
}
4040
model {
4141
target += laplace_marginal_tol(ll_function, (eta, log_ye, y),
42-
K_function, (x, n_obs, alpha, rho), theta_0, tolerance, max_num_steps,
43-
hessian_block_size, solver, max_steps_line_search);
42+
K_function, (x, n_obs, alpha, rho),
43+
(theta_0, tolerance, max_num_steps, hessian_block_size,
44+
solver, max_steps_line_search, allow_fallthrough));
4445
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
functions {
2+
// specify negative binomial likelihood with mean offset
3+
real ll_function(vector theta, // latent Gaussian
4+
real eta,
5+
vector log_ye, // mean offset
6+
array[] int y) {
7+
// observed count
8+
return neg_binomial_2_lpmf(y | exp(log_ye + theta), eta);
9+
}
10+
11+
// specify covariance function
12+
matrix K_function(array[] vector x, int n_obs, real alpha, real rho) {
13+
matrix[n_obs, n_obs] K = gp_exp_quad_cov(x, alpha, rho);
14+
for (i in 1 : n_obs)
15+
K[i, i] += 1e-8;
16+
return K;
17+
}
18+
}
19+
data {
20+
int n_obs;
21+
int n_coordinates;
22+
array[n_obs] int y;
23+
vector[n_obs] ye;
24+
array[n_obs] vector[n_coordinates] x;
25+
}
26+
27+
transformed data {
28+
vector[n_obs] log_ye = log(ye);
29+
vector[n_obs] theta_0 = rep_vector(0.0, n_obs); // initial guess
30+
}
31+
parameters {
32+
real<lower=0> alpha;
33+
real<lower=0> rho;
34+
real<lower=0> eta;
35+
}
36+
37+
38+
generated quantities {
39+
vector[n_obs] theta = laplace_latent_rng(ll_function, (eta, log_ye, y),
40+
K_function, (x, n_obs, alpha, rho), theta_0);
41+
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
functions {
2+
// specify negative binomial likelihood with mean offset
3+
real ll_function(vector theta, // latent Gaussian
4+
real eta,
5+
vector log_ye, // mean offset
6+
array[] int y) {
7+
// observed count
8+
return neg_binomial_2_lpmf(y | exp(log_ye + theta), eta);
9+
}
10+
11+
// specify covariance function
12+
matrix K_function(array[] vector x, int n_obs, real alpha, real rho) {
13+
matrix[n_obs, n_obs] K = gp_exp_quad_cov(x, alpha, rho);
14+
for (i in 1 : n_obs)
15+
K[i, i] += 1e-8;
16+
return K;
17+
}
18+
}
19+
data {
20+
int n_obs;
21+
int n_coordinates;
22+
array[n_obs] int y;
23+
vector[n_obs] ye;
24+
array[n_obs] vector[n_coordinates] x;
25+
}
26+
27+
transformed data {
28+
vector[n_obs] log_ye = log(ye);
29+
vector[n_obs] theta_0 = rep_vector(0.0, n_obs); // initial guess
30+
}
31+
parameters {
32+
real<lower=0> alpha;
33+
real<lower=0> rho;
34+
real<lower=0> eta;
35+
}
36+
37+
generated quantities {
38+
vector[n_obs] theta = laplace_latent_tol_rng(ll_function, (eta, log_ye, y),
39+
K_function, (x, n_obs, alpha, rho), (theta_0, 1,2,3,4,5,0.1));
40+
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
functions {
2+
// specify negative binomial likelihood with mean offset
3+
real ll_function(vector theta, // latent Gaussian
4+
real eta,
5+
vector log_ye, // mean offset
6+
array[] int y) {
7+
// observed count
8+
return neg_binomial_2_lpmf(y | exp(log_ye + theta), eta);
9+
}
10+
11+
// specify covariance function
12+
matrix K_function(array[] vector x, int n_obs, real alpha, real rho) {
13+
matrix[n_obs, n_obs] K = gp_exp_quad_cov(x, alpha, rho);
14+
for (i in 1 : n_obs)
15+
K[i, i] += 1e-8;
16+
return K;
17+
}
18+
}
19+
data {
20+
int n_obs;
21+
int n_coordinates;
22+
array[n_obs] int y;
23+
vector[n_obs] ye;
24+
array[n_obs] vector[n_coordinates] x;
25+
}
26+
27+
transformed data {
28+
vector[n_obs] log_ye = log(ye);
29+
vector[n_obs] theta_0 = rep_vector(0.0, n_obs); // initial guess
30+
}
31+
parameters {
32+
real<lower=0> alpha;
33+
real<lower=0> rho;
34+
real<lower=0> eta;
35+
}
36+
37+
generated quantities {
38+
vector[n_obs] theta = laplace_latent_tol_rng(ll_function, (eta, log_ye, y),
39+
K_function, (x, n_obs, alpha, rho), (theta_0, 1,2,3,4,5,0),1);
40+
}

test/integration/bad/embedded_laplace/bad_tol3.stan

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,5 +36,5 @@ parameters {
3636

3737
generated quantities {
3838
vector[n_obs] theta = laplace_latent_tol_rng(ll_function, (eta, log_ye, y),
39-
K_function, (x, n_obs, alpha, rho), theta_0, 1,2,3,4);
39+
K_function, (x, n_obs, alpha, rho), (theta_0, 1,2,3,4));
4040
}

test/integration/bad/embedded_laplace/bad_tol4.stan

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ transformed data {
3434
int hessian_block_size = 1;
3535
int solver = 1;
3636
int max_steps_line_search = 0;
37+
int allow_fallthrough = 1;
3738
}
3839
parameters {
3940
real<lower=0> alpha;
@@ -42,7 +43,7 @@ parameters {
4243
}
4344
model {
4445
target += laplace_marginal_tol(ll_function, (eta, log_ye, y),
45-
K_function, (x, n_obs, alpha, rho), theta_0,
46+
K_function, (x, n_obs, alpha, rho), (theta_0,
4647
eta, max_num_steps, hessian_block_size,
47-
solver, max_steps_line_search);
48+
solver, max_steps_line_search, allow_fallthrough));
4849
}

test/integration/bad/embedded_laplace/bad_tol5.stan

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,5 +36,5 @@ parameters {
3636

3737
generated quantities {
3838
vector[n_obs] theta = laplace_latent_tol_rng(ll_function, (eta, log_ye, y),
39-
K_function, (x, n_obs, alpha, rho), theta_0, 1,2,3,4,5.5);
39+
K_function, (x, n_obs, alpha, rho), (theta_0, 1,2,3,4,5.5,0));
4040
}

0 commit comments

Comments
 (0)