Skip to content

Commit dcf67df

Browse files
Fix pow(x, 1/2) -> sqrt(x) miscompilation for integer division
When both operands in the exponent division are UInt, Stan performs integer division (1/2 = 0), so pow(x, 1/2) is actually pow(x, 0) = 1. The partial evaluator was incorrectly rewriting this to sqrt(x). Add a type check to the guard so the simplification only fires when at least one operand is real (i.e., actual floating-point division). Also removes the comment that was flagging this as a known bug. Add test case for pow(theta, 1/2) with pure integer literals.
1 parent ae477a5 commit dcf67df

2 files changed

Lines changed: 5 additions & 3 deletions

File tree

src/analysis_and_optimization/Partial_evaluator.ml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -704,11 +704,11 @@ let rec eval_expr ?(preserve_stability = false) (e : Expr.Typed.t) =
704704
; { pattern=
705705
FunApp (StanLib ("Divide__", FnPlain, mem), [y; z])
706706
; _ } ] )
707-
when is_int 1 y && is_int 2 z ->
707+
when is_int 1 y && is_int 2 z
708+
&& not
709+
(y.meta.type_ = UInt && z.meta.type_ = UInt) ->
708710
let lub_mem = lub_mem_pat [mem] in
709711
FunApp (StanLib ("sqrt", suffix, lub_mem), [x])
710-
(* This is wrong; if both are type UInt the exponent is
711-
rounds down to zero. *)
712712
| ( "square"
713713
, [{pattern= FunApp (StanLib ("sd", FnPlain, mem), [x]); _}] )
714714
->

test/unit/Optimize.ml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2171,6 +2171,7 @@ model {
21712171
target += pow(theta, 1./2.);
21722172
target += pow(theta, 1/2.);
21732173
target += pow(theta, 1./2);
2174+
target += pow(theta, 1/2);
21742175
target += square(sd(x_vector));
21752176
target += sqrt(2);
21762177
target += sum(square(x_vector - y_vector));
@@ -2338,6 +2339,7 @@ model {
23382339
target += sqrt(34.);
23392340
target += sqrt(34.);
23402341
target += sqrt(34.);
2342+
target += pow(34., 0);
23412343
target += variance(x_vector);
23422344
target += sqrt2();
23432345
target += squared_distance(x_vector, y_vector);

0 commit comments

Comments
 (0)