@@ -55,7 +55,15 @@ let gen_num_int m t =
5555 | Transformation. Lower e -> (unwrap_int_exn m e, unwrap_int_exn m e + diff)
5656 | Upper e -> (unwrap_int_exn m e - diff, unwrap_int_exn m e)
5757 | LowerUpper (e1 , e2 ) -> (unwrap_int_exn m e1, unwrap_int_exn m e2)
58- | _ -> (def_low, def_low + diff) in
58+ | Identity -> (def_low, def_low + diff)
59+ | Ordered | PositiveOrdered | Simplex | UnitVector | SumToZero
60+ | CholeskyCorr | CholeskyCov | Correlation | Covariance | StochasticRow
61+ | StochasticColumn | TupleTransformation _ | Offset _ | Multiplier _
62+ | OffsetMultiplier _ ->
63+ Common.ICE. internal_compiler_error
64+ [% message
65+ " Unknown transformation for int" (t : Expr.Typed.t Transformation.t )]
66+ in
5967 let low = if low = 0 && up <> 1 then low + 1 else low in
6068 Random. int (up - low + 1 ) + low
6169
@@ -66,70 +74,81 @@ let gen_num_real m t =
6674 | Transformation. Lower e -> (unwrap_num_exn m e, unwrap_num_exn m e +. diff)
6775 | Upper e -> (unwrap_num_exn m e -. diff, unwrap_num_exn m e)
6876 | LowerUpper (e1 , e2 ) -> (unwrap_num_exn m e1, unwrap_num_exn m e2)
69- | _ -> (def_low, def_low +. diff) in
77+ | Identity | Offset _ | Multiplier _ | OffsetMultiplier _ ->
78+ (def_low, def_low +. diff)
79+ | Ordered | PositiveOrdered | Simplex | UnitVector | SumToZero
80+ | CholeskyCorr | CholeskyCov | Correlation | Covariance | StochasticRow
81+ | StochasticColumn | TupleTransformation _ ->
82+ Common.ICE. internal_compiler_error
83+ [% message
84+ " Unknown transformation for real"
85+ (t : Expr.Typed.t Transformation.t )] in
7086 Random. float_range low up
7187
72- let rec repeat n e =
73- match n with n when n < = 0 -> [] | m -> e :: repeat (m - 1 ) e
74-
75- let rec repeat_th n f =
76- match n with n when n < = 0 -> [] | m -> f () :: repeat_th (m - 1 ) f
77-
78- let gen_bounded m gen e =
79- match Expr.Helpers. try_unpack (eval_expr m e) with
80- | Some unpacked_e -> List. map ~f: gen unpacked_e
81- | None ->
82- reject e.meta.loc
83- (Fmt. str " Cannot evaluate bounded (upper OR lower) expr: %a"
84- Expr.Typed. pp e)
88+ let repeat n e = List. init n ~f: (Fn. const e)
89+ let repeat_th n f = List. init n ~f: (fun _ -> f () )
90+ let random_floats n = repeat_th n (fun () -> Random. float 2. )
8591
86- let gen_ul_bounded m gen e1 e2 =
92+ let unpack_or_repeat n (e : Expr.Typed.t ) : Expr.Typed.t list =
93+ match e.pattern with
94+ | FunApp (CompilerInternal (FnMakeRowVec | FnMakeArray ), l ) -> l
95+ | FunApp
96+ ( StanLib (" Transpose__" , FnPlain , _)
97+ , [{pattern= FunApp (CompilerInternal FnMakeRowVec , l); _}] ) ->
98+ l
99+ | _ -> repeat n e
100+
101+ let gen_bounded m n gen e =
102+ let exprs = unpack_or_repeat n (eval_expr m e) in
103+ List. map ~f: gen exprs
104+
105+ let gen_ul_bounded m n gen e1 e2 =
87106 let create_bounds l u =
88107 List. map2_exn ~f: (fun x y -> gen (Transformation. LowerUpper (x, y))) l u
89108 in
90- let e1, e2 = (eval_expr m e1, eval_expr m e2) in
91- match Expr.Helpers. (try_unpack e1, try_unpack e2) with
92- | Some unpacked_e1 , Some unpacked_e2 -> create_bounds unpacked_e1 unpacked_e2
93- | None , Some unpacked_e2 ->
94- create_bounds
95- (List. init (List. length unpacked_e2) ~f: (fun _ -> e1))
96- unpacked_e2
97- | Some unpacked_e1 , None ->
98- create_bounds unpacked_e1
99- (List. init (List. length unpacked_e1) ~f: (fun _ -> e2))
100- | None , None ->
101- reject e1.meta.loc
102- (Fmt. str " Cannot evaluate upper and lower bound expr: %a and %a"
103- Expr.Typed. pp e1 Expr.Typed. pp e2)
109+ let es1 = unpack_or_repeat n (eval_expr m e1) in
110+ let es2 = unpack_or_repeat n (eval_expr m e2) in
111+ create_bounds es1 es2
104112
105113let gen_row_vector m n t =
106114 match (t : Expr.Typed.t Transformation.t ) with
107- | Transformation. Lower ( { meta = { type_ = UVector | URowVector ; _} ; _} as e ) ->
108- gen_bounded m (fun x -> gen_num_real m (Transformation. Lower x)) e
115+ | Transformation. Lower e ->
116+ gen_bounded m n (fun x -> gen_num_real m (Lower x)) e
109117 |> Expr.Helpers. row_vector
110- | Transformation. Upper ( { meta = { type_ = UVector | URowVector ; _} ; _} as e ) ->
111- gen_bounded m (fun x -> gen_num_real m (Transformation. Upper x)) e
118+ | Upper e ->
119+ gen_bounded m n (fun x -> gen_num_real m (Upper x)) e
112120 |> Expr.Helpers. row_vector
113- | Transformation. LowerUpper
114- ( ({meta= {type_= UVector | URowVector | UReal | UInt ; _}; _} as e1)
115- , ({meta= {type_= UVector | URowVector ; _}; _} as e2) )
116- | Transformation. LowerUpper
117- ( ({meta= {type_= UVector | URowVector ; _}; _} as e1)
118- , ({meta= {type_= UReal | UInt ; _ }; _ } as e2 ) ) ->
119- gen_ul_bounded m (gen_num_real m) e1 e2 |> Expr.Helpers. row_vector
120- | _ -> Expr.Helpers. row_vector (repeat_th n (fun _ -> gen_num_real m t))
121+ | LowerUpper (e1 , e2 ) ->
122+ gen_ul_bounded m n (gen_num_real m) e1 e2 |> Expr.Helpers. row_vector
123+ | Identity | Offset _ | Multiplier _ | OffsetMultiplier _ ->
124+ Expr.Helpers. row_vector (repeat_th n (fun () -> gen_num_real m t))
125+ | Ordered | PositiveOrdered | Simplex | UnitVector | SumToZero
126+ | CholeskyCorr | CholeskyCov | Correlation | Covariance | StochasticRow
127+ | StochasticColumn | TupleTransformation _ ->
128+ Common.ICE. internal_compiler_error
129+ [% message
130+ " Unknown transformation for row_vector"
131+ (t : Expr.Typed.t Transformation.t )]
132+
133+ let sum_to_zero_floats n =
134+ let l = random_floats n in
135+ let sum = List. fold l ~init: 0. ~f: ( +. ) in
136+ List. map l ~f: (fun x -> x -. (sum /. float_of_int n))
137+
138+ let simplex_floats n =
139+ let l = random_floats n in
140+ let sum = List. fold l ~init: 0. ~f: ( +. ) in
141+ List. map l ~f: (fun x -> x /. sum)
121142
122143let gen_vector m n t =
123144 let gen_ordered n =
124- let l = repeat_th n ( fun _ -> Random. float 1. ) in
145+ let l = random_floats n in
125146 List. fold_map l ~init: 0. ~f: (fun accum elt ->
126147 let elt = accum +. elt in
127148 (elt, elt)) in
128149 match t with
129150 | Transformation. Simplex ->
130- let l = repeat_th n (fun _ -> Random. float 1. ) in
131- let sum = List. fold l ~init: 0. ~f: ( +. ) in
132- let l = List. map l ~f: (fun x -> x /. sum) in
151+ let l = simplex_floats n in
133152 Expr.Helpers. vector l
134153 | Ordered ->
135154 let max, l = gen_ordered n in
@@ -139,19 +158,28 @@ let gen_vector m n t =
139158 let _, l = gen_ordered n in
140159 Expr.Helpers. vector l
141160 | UnitVector ->
142- let l = repeat_th n ( fun _ -> Random. float 1. ) in
161+ let l = random_floats n in
143162 let sum =
144163 Float. sqrt
145164 (List. fold l ~init: 0. ~f: (fun accum elt -> accum +. (elt ** 2. )))
146165 in
147166 let l = List. map l ~f: (fun x -> x /. sum) in
148167 Expr.Helpers. vector l
149- | _ ->
150- let v = Expr.Helpers. unary_op Transpose (gen_row_vector m n t) in
151- {v with meta= {v.meta with type_= UVector }}
168+ | SumToZero ->
169+ let l = sum_to_zero_floats n in
170+ Expr.Helpers. vector l
171+ | Identity | Offset _ | Multiplier _ | OffsetMultiplier _ | Lower _
172+ | LowerUpper _ | Upper _ ->
173+ Expr.Helpers. transpose (gen_row_vector m n t)
174+ | CholeskyCorr | CholeskyCov | Correlation | Covariance | StochasticRow
175+ | StochasticColumn | TupleTransformation _ ->
176+ Common.ICE. internal_compiler_error
177+ [% message
178+ " Unknown transformation for vector"
179+ (t : Expr.Typed.t Transformation.t )]
152180
153181let gen_cov_unwrapped n =
154- let l = repeat_th (n * n) ( fun _ -> Random. float 2. ) in
182+ let l = random_floats (n * n) in
155183 let l_mat = vect_to_mat l n in
156184 matprod l_mat (transpose l_mat)
157185
@@ -169,22 +197,20 @@ let gen_diag_mat l =
169197let fill_lower_triangular m =
170198 let fill_row i l =
171199 let _, tl = List. split_n l i in
172- List. init ~f: ( fun _ -> Random. float 2. ) i @ tl in
200+ random_floats i @ tl in
173201 List. mapi ~f: fill_row m
174202
175203let pad_mat mm m n =
176- let padding_mat =
177- List. init (m - n) ~f: (fun _ -> List. init n ~f: (fun _ -> Random. float 2. ))
178- in
204+ let padding_mat = repeat_th (m - n) (fun () -> random_floats n) in
179205 wrap_real_mat (mm @ padding_mat)
180206
181207let gen_cov_cholesky m n =
182- let diag_mat = gen_diag_mat (List. init ~f: ( fun _ -> Random. float 2. ) n) in
208+ let diag_mat = gen_diag_mat (random_floats n) in
183209 let filled_mat = fill_lower_triangular diag_mat in
184210 if m < = n then wrap_real_mat filled_mat else pad_mat filled_mat m n
185211
186212let gen_corr_cholesky_unwrapped n =
187- let diag_mat = gen_diag_mat (List. init ~f: ( fun _ -> Random. float 2. ) n) in
213+ let diag_mat = gen_diag_mat (random_floats n) in
188214 let filled_mat = fill_lower_triangular diag_mat in
189215 let row_normalizer l =
190216 let row_norm =
@@ -203,6 +229,16 @@ let gen_corr_matrix n =
203229 let corr_chol = gen_corr_cholesky_unwrapped n in
204230 wrap_real_mat (matprod corr_chol (transpose corr_chol))
205231
232+ let gen_sum_to_zero_matrix m n =
233+ (* to make each row and column sum to zero: - make each row sum to zero - add
234+ a new column which is - sum(rest of row) *)
235+ let rows = repeat_th (m - 1 ) (fun () -> sum_to_zero_floats n) in
236+ let col_sums =
237+ List. fold rows ~init: (repeat n 0. ) ~f: (fun accum row ->
238+ List. map2_exn accum row ~f: ( +. )) in
239+ let last_row = List. map col_sums ~f: (fun x -> -. x) in
240+ wrap_real_mat (rows @ [last_row])
241+
206242let gen_matrix mm m n t =
207243 match (t : Expr.Typed.t Transformation.t ) with
208244 | Covariance -> gen_cov_matrix m
@@ -211,17 +247,28 @@ let gen_matrix mm m n t =
211247 | CholeskyCorr -> gen_corr_cholesky m
212248 | Lower ({meta = {type_ = UMatrix ; _} ; _} as e ) ->
213249 Expr.Helpers. matrix_from_rows
214- (gen_bounded mm (fun x -> gen_row_vector mm n (Lower x)) e)
250+ (gen_bounded mm m (fun x -> gen_row_vector mm n (Lower x)) e)
215251 | Upper ({meta = {type_ = UMatrix ; _} ; _} as e ) ->
216252 Expr.Helpers. matrix_from_rows
217- (gen_bounded mm (fun x -> gen_row_vector mm n (Upper x)) e)
253+ (gen_bounded mm m (fun x -> gen_row_vector mm n (Upper x)) e)
218254 | LowerUpper (({meta= {type_= UMatrix ; _}; _} as e1), e2)
219255 | LowerUpper (e1 , ({meta = {type_ = UMatrix ; _} ; _} as e2 )) ->
220256 Expr.Helpers. matrix_from_rows
221- (gen_ul_bounded mm (gen_row_vector mm n) e1 e2)
222- | _ ->
257+ (gen_ul_bounded mm m (gen_row_vector mm n) e1 e2)
258+ | StochasticRow ->
259+ Expr.Helpers. matrix (repeat_th m (fun () -> simplex_floats n))
260+ | StochasticColumn ->
261+ Expr.Helpers. matrix (transpose (repeat_th n (fun () -> simplex_floats m)))
262+ | SumToZero -> gen_sum_to_zero_matrix m n
263+ | Identity | Lower _ | Upper _ | LowerUpper _ | Offset _ | Multiplier _
264+ | OffsetMultiplier _ ->
223265 Expr.Helpers. matrix_from_rows
224266 (repeat_th m (fun () -> gen_row_vector mm n t))
267+ | Ordered | PositiveOrdered | Simplex | UnitVector | TupleTransformation _ ->
268+ Common.ICE. internal_compiler_error
269+ [% message
270+ " Unknown transformation for matrix"
271+ (t : Expr.Typed.t Transformation.t )]
225272
226273let gen_complex_unwrapped () =
227274 ( gen_num_real Map.Poly. empty Transformation. Identity
@@ -244,13 +291,13 @@ let rec gen_array m st n t =
244291 match (t : Expr.Typed.t Transformation.t ) with
245292 | Lower ({meta = {type_ = UArray _ ; _} ; _} as e ) ->
246293 Expr.Helpers. array_expr
247- (gen_bounded m (fun x -> generate_value m st (Lower x)) e)
294+ (gen_bounded m n (fun x -> generate_value m st (Lower x)) e)
248295 | Upper ({meta = {type_ = UArray _ ; _} ; _} as e ) ->
249296 Expr.Helpers. array_expr
250- (gen_bounded m (fun x -> generate_value m st (Upper x)) e)
297+ (gen_bounded m n (fun x -> generate_value m st (Upper x)) e)
251298 | LowerUpper (({meta= {type_= UArray _; _}; _} as e1), e2)
252299 | LowerUpper (e1 , ({meta = {type_ = UArray _ ; _} ; _} as e2 )) ->
253- Expr.Helpers. array_expr (gen_ul_bounded m (generate_value m st) e1 e2)
300+ Expr.Helpers. array_expr (gen_ul_bounded m n (generate_value m st) e1 e2)
254301 | _ -> Expr.Helpers. array_expr (repeat_th n elt)
255302
256303and gen_tuple m st t =
0 commit comments