@@ -55,7 +55,15 @@ let gen_num_int m t =
5555 | Transformation. Lower e -> (unwrap_int_exn m e, unwrap_int_exn m e + diff)
5656 | Upper e -> (unwrap_int_exn m e - diff, unwrap_int_exn m e)
5757 | LowerUpper (e1 , e2 ) -> (unwrap_int_exn m e1, unwrap_int_exn m e2)
58- | _ -> (def_low, def_low + diff) in
58+ | Identity -> (def_low, def_low + diff)
59+ | Ordered | PositiveOrdered | Simplex | UnitVector | SumToZero
60+ | CholeskyCorr | CholeskyCov | Correlation | Covariance | StochasticRow
61+ | StochasticColumn | TupleTransformation _ | Offset _ | Multiplier _
62+ | OffsetMultiplier _ ->
63+ Common.ICE. internal_compiler_error
64+ [% message
65+ " Unknown transformation for int" (t : Expr.Typed.t Transformation.t )]
66+ in
5967 let low = if low = 0 && up <> 1 then low + 1 else low in
6068 Random. int (up - low + 1 ) + low
6169
@@ -66,7 +74,15 @@ let gen_num_real m t =
6674 | Transformation. Lower e -> (unwrap_num_exn m e, unwrap_num_exn m e +. diff)
6775 | Upper e -> (unwrap_num_exn m e -. diff, unwrap_num_exn m e)
6876 | LowerUpper (e1 , e2 ) -> (unwrap_num_exn m e1, unwrap_num_exn m e2)
69- | _ -> (def_low, def_low +. diff) in
77+ | Identity | Offset _ | Multiplier _ | OffsetMultiplier _ ->
78+ (def_low, def_low +. diff)
79+ | Ordered | PositiveOrdered | Simplex | UnitVector | SumToZero
80+ | CholeskyCorr | CholeskyCov | Correlation | Covariance | StochasticRow
81+ | StochasticColumn | TupleTransformation _ ->
82+ Common.ICE. internal_compiler_error
83+ [% message
84+ " Unknown transformation for real"
85+ (t : Expr.Typed.t Transformation.t )] in
7086 Random. float_range low up
7187
7288let rec repeat n e =
@@ -75,49 +91,53 @@ let rec repeat n e =
7591let rec repeat_th n f =
7692 match n with n when n < = 0 -> [] | m -> f () :: repeat_th (m - 1 ) f
7793
78- let gen_bounded m gen e =
79- match Expr.Helpers. try_unpack ( eval_expr m e) with
80- | Some unpacked_e -> List. map ~f: gen unpacked_e
81- | None ->
82- reject e.meta.loc
83- ( Fmt. str " Cannot evaluate bounded (upper OR lower) expr: %a "
84- Expr.Typed. pp e)
94+ let gen_bounded m n gen e =
95+ let evaled = eval_expr m e in
96+ let exprs =
97+ match Expr.Helpers. try_unpack evaled with
98+ | Some unpacked -> unpacked
99+ | None -> List. init n ~f: ( Fn. const evaled) in
100+ List. map ~f: gen exprs
85101
86- let gen_ul_bounded m gen e1 e2 =
102+ let gen_ul_bounded m n gen e1 e2 =
87103 let create_bounds l u =
88104 List. map2_exn ~f: (fun x y -> gen (Transformation. LowerUpper (x, y))) l u
89105 in
90- let e1, e2 = (eval_expr m e1, eval_expr m e2) in
91- match Expr.Helpers. (try_unpack e1, try_unpack e2) with
92- | Some unpacked_e1 , Some unpacked_e2 -> create_bounds unpacked_e1 unpacked_e2
93- | None , Some unpacked_e2 ->
94- create_bounds
95- (List. init (List. length unpacked_e2) ~f: (fun _ -> e1))
96- unpacked_e2
97- | Some unpacked_e1 , None ->
98- create_bounds unpacked_e1
99- (List. init (List. length unpacked_e1) ~f: (fun _ -> e2))
100- | None , None ->
101- reject e1.meta.loc
102- (Fmt. str " Cannot evaluate upper and lower bound expr: %a and %a"
103- Expr.Typed. pp e1 Expr.Typed. pp e2)
106+ let e1 = eval_expr m e1 in
107+ let es1 =
108+ match Expr.Helpers. try_unpack e1 with
109+ | Some unpacked -> unpacked
110+ | None -> List. init n ~f: (Fn. const e1) in
111+ let e2 = eval_expr m e2 in
112+ let es2 =
113+ match Expr.Helpers. try_unpack e2 with
114+ | Some unpacked -> unpacked
115+ | None -> List. init n ~f: (Fn. const e2) in
116+ create_bounds es1 es2
104117
105118let gen_row_vector m n t =
106119 match (t : Expr.Typed.t Transformation.t ) with
107- | Transformation. Lower ({meta = {type_ = UVector | URowVector ; _} ; _} as e ) ->
108- gen_bounded m (fun x -> gen_num_real m (Transformation. Lower x)) e
120+ | Transformation. Lower
121+ ({meta= {type_= UVector | URowVector | UReal | UInt ; _ }; _ } as e ) ->
122+ gen_bounded m n (fun x -> gen_num_real m (Lower x)) e
109123 |> Expr.Helpers. row_vector
110- | Transformation. Upper ({meta = {type_ = UVector | URowVector ; _} ; _} as e ) ->
111- gen_bounded m (fun x -> gen_num_real m (Transformation. Upper x)) e
124+ | Upper ({meta = {type_ = UVector | URowVector | UReal | UInt ; _} ; _} as e ) ->
125+ gen_bounded m n (fun x -> gen_num_real m (Upper x)) e
112126 |> Expr.Helpers. row_vector
113- | Transformation. LowerUpper
127+ | LowerUpper
114128 ( ({meta= {type_= UVector | URowVector | UReal | UInt ; _}; _} as e1)
115- , ({meta= {type_= UVector | URowVector ; _}; _} as e2) )
116- | Transformation. LowerUpper
117- ( ({meta= {type_= UVector | URowVector ; _}; _} as e1)
118- , ({meta= {type_= UReal | UInt ; _ }; _ } as e2 ) ) ->
119- gen_ul_bounded m (gen_num_real m) e1 e2 |> Expr.Helpers. row_vector
120- | _ -> Expr.Helpers. row_vector (repeat_th n (fun _ -> gen_num_real m t))
129+ , ({meta= {type_= UVector | URowVector | UReal | UInt ; _ }; _ } as e2 ) ) ->
130+ gen_ul_bounded m n (gen_num_real m) e1 e2 |> Expr.Helpers. row_vector
131+ | Identity | Offset _ | Multiplier _ | OffsetMultiplier _ ->
132+ Expr.Helpers. row_vector (repeat_th n (fun _ -> gen_num_real m t))
133+ | Ordered | PositiveOrdered | Simplex | UnitVector | SumToZero
134+ | CholeskyCorr | CholeskyCov | Correlation | Covariance | StochasticRow
135+ | StochasticColumn | TupleTransformation _ | Lower _ | LowerUpper _ | Upper _
136+ ->
137+ Common.ICE. internal_compiler_error
138+ [% message
139+ " Unknown transformation for (row) vector"
140+ (t : Expr.Typed.t Transformation.t )]
121141
122142let gen_vector m n t =
123143 let gen_ordered n =
@@ -146,9 +166,15 @@ let gen_vector m n t =
146166 in
147167 let l = List. map l ~f: (fun x -> x /. sum) in
148168 Expr.Helpers. vector l
149- | _ ->
150- let v = Expr.Helpers. unary_op Transpose (gen_row_vector m n t) in
151- {v with meta= {v.meta with type_= UVector }}
169+ | SumToZero ->
170+ let l = repeat_th n (fun _ -> Random. float 1. ) in
171+ let sum = List. fold l ~init: 0. ~f: ( +. ) in
172+ let l = List. map l ~f: (fun x -> x -. (sum /. float_of_int n)) in
173+ Expr.Helpers. vector l
174+ | Identity | Offset _ | Multiplier _ | OffsetMultiplier _ | CholeskyCorr
175+ | CholeskyCov | Correlation | Covariance | StochasticRow | StochasticColumn
176+ | TupleTransformation _ | Lower _ | LowerUpper _ | Upper _ ->
177+ Expr.Helpers. transpose (gen_row_vector m n t)
152178
153179let gen_cov_unwrapped n =
154180 let l = repeat_th (n * n) (fun _ -> Random. float 2. ) in
@@ -203,6 +229,20 @@ let gen_corr_matrix n =
203229 let corr_chol = gen_corr_cholesky_unwrapped n in
204230 wrap_real_mat (matprod corr_chol (transpose corr_chol))
205231
232+ let gen_sum_to_zero_matrix m n =
233+ (* to make each row and column sum to zero: - make each row sum to zero - add
234+ a new column which is - sum(rest of row) *)
235+ let rows =
236+ repeat_th (m - 1 ) (fun _ ->
237+ let row = repeat_th (n - 1 ) (fun _ -> Random. float 2. ) in
238+ let row_sum = List. fold row ~init: 0. ~f: ( +. ) in
239+ row @ [-. row_sum]) in
240+ let col_sums =
241+ List. fold rows ~init: (repeat n 0. ) ~f: (fun accum row ->
242+ List. map2_exn accum row ~f: ( +. )) in
243+ let last_row = List. map col_sums ~f: (fun x -> -. x) in
244+ wrap_real_mat (rows @ [last_row])
245+
206246let gen_matrix mm m n t =
207247 match (t : Expr.Typed.t Transformation.t ) with
208248 | Covariance -> gen_cov_matrix m
@@ -211,17 +251,33 @@ let gen_matrix mm m n t =
211251 | CholeskyCorr -> gen_corr_cholesky m
212252 | Lower ({meta = {type_ = UMatrix ; _} ; _} as e ) ->
213253 Expr.Helpers. matrix_from_rows
214- (gen_bounded mm (fun x -> gen_row_vector mm n (Lower x)) e)
254+ (gen_bounded mm m (fun x -> gen_row_vector mm n (Lower x)) e)
215255 | Upper ({meta = {type_ = UMatrix ; _} ; _} as e ) ->
216256 Expr.Helpers. matrix_from_rows
217- (gen_bounded mm (fun x -> gen_row_vector mm n (Upper x)) e)
257+ (gen_bounded mm m (fun x -> gen_row_vector mm n (Upper x)) e)
218258 | LowerUpper (({meta= {type_= UMatrix ; _}; _} as e1), e2)
219259 | LowerUpper (e1 , ({meta = {type_ = UMatrix ; _} ; _} as e2 )) ->
220260 Expr.Helpers. matrix_from_rows
221- (gen_ul_bounded mm (gen_row_vector mm n) e1 e2)
222- | _ ->
261+ (gen_ul_bounded mm m (gen_row_vector mm n) e1 e2)
262+ | StochasticRow ->
263+ Expr.Helpers. matrix_from_rows
264+ (repeat_th m (fun () ->
265+ Expr.Helpers. transpose (gen_vector mm n Simplex )))
266+ | StochasticColumn ->
267+ Expr.Helpers. transpose
268+ (Expr.Helpers. matrix_from_rows
269+ (repeat_th m (fun () ->
270+ Expr.Helpers. transpose (gen_vector mm n Simplex ))))
271+ | SumToZero -> gen_sum_to_zero_matrix m n
272+ | Identity | Lower _ | Upper _ | LowerUpper _ | Offset _ | Multiplier _
273+ | OffsetMultiplier _ ->
223274 Expr.Helpers. matrix_from_rows
224275 (repeat_th m (fun () -> gen_row_vector mm n t))
276+ | Ordered | PositiveOrdered | Simplex | UnitVector | TupleTransformation _ ->
277+ Common.ICE. internal_compiler_error
278+ [% message
279+ " Unknown transformation for matrix"
280+ (t : Expr.Typed.t Transformation.t )]
225281
226282let gen_complex_unwrapped () =
227283 ( gen_num_real Map.Poly. empty Transformation. Identity
@@ -244,13 +300,13 @@ let rec gen_array m st n t =
244300 match (t : Expr.Typed.t Transformation.t ) with
245301 | Lower ({meta = {type_ = UArray _ ; _} ; _} as e ) ->
246302 Expr.Helpers. array_expr
247- (gen_bounded m (fun x -> generate_value m st (Lower x)) e)
303+ (gen_bounded m n (fun x -> generate_value m st (Lower x)) e)
248304 | Upper ({meta = {type_ = UArray _ ; _} ; _} as e ) ->
249305 Expr.Helpers. array_expr
250- (gen_bounded m (fun x -> generate_value m st (Upper x)) e)
306+ (gen_bounded m n (fun x -> generate_value m st (Upper x)) e)
251307 | LowerUpper (({meta= {type_= UArray _; _}; _} as e1), e2)
252308 | LowerUpper (e1 , ({meta = {type_ = UArray _ ; _} ; _} as e2 )) ->
253- Expr.Helpers. array_expr (gen_ul_bounded m (generate_value m st) e1 e2)
309+ Expr.Helpers. array_expr (gen_ul_bounded m n (generate_value m st) e1 e2)
254310 | _ -> Expr.Helpers. array_expr (repeat_th n elt)
255311
256312and gen_tuple m st t =
0 commit comments