@@ -85,71 +85,70 @@ let gen_num_real m t =
8585 (t : Expr.Typed.t Transformation.t )] in
8686 Random. float_range low up
8787
88- let rec repeat n e =
89- match n with n when n < = 0 -> [] | m -> e :: repeat (m - 1 ) e
88+ let repeat n e = List. init n ~f: (Fn. const e)
89+ let repeat_th n f = List. init n ~f: (fun _ -> f () )
90+ let random_floats n = repeat_th n (fun () -> Random. float 2. )
9091
91- let rec repeat_th n f =
92- match n with n when n < = 0 -> [] | m -> f () :: repeat_th (m - 1 ) f
92+ let unpack_or_repeat n (e : Expr.Typed.t ) : Expr.Typed.t list =
93+ match e.pattern with
94+ | FunApp (CompilerInternal (FnMakeRowVec | FnMakeArray ), l ) -> l
95+ | FunApp
96+ ( StanLib (" Transpose__" , FnPlain , _)
97+ , [{pattern= FunApp (CompilerInternal FnMakeRowVec , l); _}] ) ->
98+ l
99+ | _ -> repeat n e
93100
94101let gen_bounded m n gen e =
95- let evaled = eval_expr m e in
96- let exprs =
97- match Expr.Helpers. try_unpack evaled with
98- | Some unpacked -> unpacked
99- | None -> List. init n ~f: (Fn. const evaled) in
102+ let exprs = unpack_or_repeat n (eval_expr m e) in
100103 List. map ~f: gen exprs
101104
102105let gen_ul_bounded m n gen e1 e2 =
103106 let create_bounds l u =
104107 List. map2_exn ~f: (fun x y -> gen (Transformation. LowerUpper (x, y))) l u
105108 in
106- let e1 = eval_expr m e1 in
107- let es1 =
108- match Expr.Helpers. try_unpack e1 with
109- | Some unpacked -> unpacked
110- | None -> List. init n ~f: (Fn. const e1) in
111- let e2 = eval_expr m e2 in
112- let es2 =
113- match Expr.Helpers. try_unpack e2 with
114- | Some unpacked -> unpacked
115- | None -> List. init n ~f: (Fn. const e2) in
109+ let es1 = unpack_or_repeat n (eval_expr m e1) in
110+ let es2 = unpack_or_repeat n (eval_expr m e2) in
116111 create_bounds es1 es2
117112
118113let gen_row_vector m n t =
119114 match (t : Expr.Typed.t Transformation.t ) with
120- | Transformation. Lower
121- ({meta= {type_= UVector | URowVector | UReal | UInt ; _ }; _ } as e ) ->
115+ | Transformation. Lower e ->
122116 gen_bounded m n (fun x -> gen_num_real m (Lower x)) e
123117 |> Expr.Helpers. row_vector
124- | Upper ( { meta = { type_ = UVector | URowVector | UReal | UInt ; _} ; _} as e ) ->
118+ | Upper e ->
125119 gen_bounded m n (fun x -> gen_num_real m (Upper x)) e
126120 |> Expr.Helpers. row_vector
127- | LowerUpper
128- ( ({meta= {type_= UVector | URowVector | UReal | UInt ; _}; _} as e1)
129- , ({meta= {type_= UVector | URowVector | UReal | UInt ; _ }; _ } as e2 ) ) ->
121+ | LowerUpper (e1 , e2 ) ->
130122 gen_ul_bounded m n (gen_num_real m) e1 e2 |> Expr.Helpers. row_vector
131123 | Identity | Offset _ | Multiplier _ | OffsetMultiplier _ ->
132- Expr.Helpers. row_vector (repeat_th n (fun _ -> gen_num_real m t))
124+ Expr.Helpers. row_vector (repeat_th n (fun () -> gen_num_real m t))
133125 | Ordered | PositiveOrdered | Simplex | UnitVector | SumToZero
134126 | CholeskyCorr | CholeskyCov | Correlation | Covariance | StochasticRow
135- | StochasticColumn | TupleTransformation _ | Lower _ | LowerUpper _ | Upper _
136- ->
127+ | StochasticColumn | TupleTransformation _ ->
137128 Common.ICE. internal_compiler_error
138129 [% message
139- " Unknown transformation for (row) vector "
130+ " Unknown transformation for row_vector "
140131 (t : Expr.Typed.t Transformation.t )]
141132
133+ let sum_to_zero_floats n =
134+ let l = random_floats n in
135+ let sum = List. fold l ~init: 0. ~f: ( +. ) in
136+ List. map l ~f: (fun x -> x -. (sum /. float_of_int n))
137+
138+ let simplex_floats n =
139+ let l = random_floats n in
140+ let sum = List. fold l ~init: 0. ~f: ( +. ) in
141+ List. map l ~f: (fun x -> x /. sum)
142+
142143let gen_vector m n t =
143144 let gen_ordered n =
144- let l = repeat_th n ( fun _ -> Random. float 1. ) in
145+ let l = random_floats n in
145146 List. fold_map l ~init: 0. ~f: (fun accum elt ->
146147 let elt = accum +. elt in
147148 (elt, elt)) in
148149 match t with
149150 | Transformation. Simplex ->
150- let l = repeat_th n (fun _ -> Random. float 1. ) in
151- let sum = List. fold l ~init: 0. ~f: ( +. ) in
152- let l = List. map l ~f: (fun x -> x /. sum) in
151+ let l = simplex_floats n in
153152 Expr.Helpers. vector l
154153 | Ordered ->
155154 let max, l = gen_ordered n in
@@ -159,25 +158,28 @@ let gen_vector m n t =
159158 let _, l = gen_ordered n in
160159 Expr.Helpers. vector l
161160 | UnitVector ->
162- let l = repeat_th n ( fun _ -> Random. float 1. ) in
161+ let l = random_floats n in
163162 let sum =
164163 Float. sqrt
165164 (List. fold l ~init: 0. ~f: (fun accum elt -> accum +. (elt ** 2. )))
166165 in
167166 let l = List. map l ~f: (fun x -> x /. sum) in
168167 Expr.Helpers. vector l
169168 | SumToZero ->
170- let l = repeat_th n (fun _ -> Random. float 1. ) in
171- let sum = List. fold l ~init: 0. ~f: ( +. ) in
172- let l = List. map l ~f: (fun x -> x -. (sum /. float_of_int n)) in
169+ let l = sum_to_zero_floats n in
173170 Expr.Helpers. vector l
174- | Identity | Offset _ | Multiplier _ | OffsetMultiplier _ | CholeskyCorr
175- | CholeskyCov | Correlation | Covariance | StochasticRow | StochasticColumn
176- | TupleTransformation _ | Lower _ | LowerUpper _ | Upper _ ->
171+ | Identity | Offset _ | Multiplier _ | OffsetMultiplier _ | Lower _
172+ | LowerUpper _ | Upper _ ->
177173 Expr.Helpers. transpose (gen_row_vector m n t)
174+ | CholeskyCorr | CholeskyCov | Correlation | Covariance | StochasticRow
175+ | StochasticColumn | TupleTransformation _ ->
176+ Common.ICE. internal_compiler_error
177+ [% message
178+ " Unknown transformation for vector"
179+ (t : Expr.Typed.t Transformation.t )]
178180
179181let gen_cov_unwrapped n =
180- let l = repeat_th (n * n) ( fun _ -> Random. float 2. ) in
182+ let l = random_floats (n * n) in
181183 let l_mat = vect_to_mat l n in
182184 matprod l_mat (transpose l_mat)
183185
@@ -195,22 +197,20 @@ let gen_diag_mat l =
195197let fill_lower_triangular m =
196198 let fill_row i l =
197199 let _, tl = List. split_n l i in
198- List. init ~f: ( fun _ -> Random. float 2. ) i @ tl in
200+ random_floats i @ tl in
199201 List. mapi ~f: fill_row m
200202
201203let pad_mat mm m n =
202- let padding_mat =
203- List. init (m - n) ~f: (fun _ -> List. init n ~f: (fun _ -> Random. float 2. ))
204- in
204+ let padding_mat = repeat_th (m - n) (fun () -> random_floats n) in
205205 wrap_real_mat (mm @ padding_mat)
206206
207207let gen_cov_cholesky m n =
208- let diag_mat = gen_diag_mat (List. init ~f: ( fun _ -> Random. float 2. ) n) in
208+ let diag_mat = gen_diag_mat (random_floats n) in
209209 let filled_mat = fill_lower_triangular diag_mat in
210210 if m < = n then wrap_real_mat filled_mat else pad_mat filled_mat m n
211211
212212let gen_corr_cholesky_unwrapped n =
213- let diag_mat = gen_diag_mat (List. init ~f: ( fun _ -> Random. float 2. ) n) in
213+ let diag_mat = gen_diag_mat (random_floats n) in
214214 let filled_mat = fill_lower_triangular diag_mat in
215215 let row_normalizer l =
216216 let row_norm =
@@ -232,11 +232,7 @@ let gen_corr_matrix n =
232232let gen_sum_to_zero_matrix m n =
233233 (* to make each row and column sum to zero: - make each row sum to zero - add
234234 a new column which is - sum(rest of row) *)
235- let rows =
236- repeat_th (m - 1 ) (fun _ ->
237- let row = repeat_th (n - 1 ) (fun _ -> Random. float 2. ) in
238- let row_sum = List. fold row ~init: 0. ~f: ( +. ) in
239- row @ [-. row_sum]) in
235+ let rows = repeat_th (m - 1 ) (fun () -> sum_to_zero_floats n) in
240236 let col_sums =
241237 List. fold rows ~init: (repeat n 0. ) ~f: (fun accum row ->
242238 List. map2_exn accum row ~f: ( +. )) in
@@ -260,14 +256,9 @@ let gen_matrix mm m n t =
260256 Expr.Helpers. matrix_from_rows
261257 (gen_ul_bounded mm m (gen_row_vector mm n) e1 e2)
262258 | StochasticRow ->
263- Expr.Helpers. matrix_from_rows
264- (repeat_th m (fun () ->
265- Expr.Helpers. transpose (gen_vector mm n Simplex )))
259+ Expr.Helpers. matrix (repeat_th m (fun () -> simplex_floats n))
266260 | StochasticColumn ->
267- Expr.Helpers. transpose
268- (Expr.Helpers. matrix_from_rows
269- (repeat_th m (fun () ->
270- Expr.Helpers. transpose (gen_vector mm n Simplex ))))
261+ Expr.Helpers. matrix (transpose (repeat_th n (fun () -> simplex_floats m)))
271262 | SumToZero -> gen_sum_to_zero_matrix m n
272263 | Identity | Lower _ | Upper _ | LowerUpper _ | Offset _ | Multiplier _
273264 | OffsetMultiplier _ ->
0 commit comments