Skip to content

Commit 0d0d9e4

Browse files
committed
Update debug data generation for new-ish constrained types
1 parent 7e0092e commit 0d0d9e4

4 files changed

Lines changed: 396 additions & 265 deletions

File tree

src/analysis_and_optimization/Debug_data_generation.ml

Lines changed: 101 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,15 @@ let gen_num_int m t =
5555
| Transformation.Lower e -> (unwrap_int_exn m e, unwrap_int_exn m e + diff)
5656
| Upper e -> (unwrap_int_exn m e - diff, unwrap_int_exn m e)
5757
| LowerUpper (e1, e2) -> (unwrap_int_exn m e1, unwrap_int_exn m e2)
58-
| _ -> (def_low, def_low + diff) in
58+
| Identity -> (def_low, def_low + diff)
59+
| Ordered | PositiveOrdered | Simplex | UnitVector | SumToZero
60+
|CholeskyCorr | CholeskyCov | Correlation | Covariance | StochasticRow
61+
|StochasticColumn | TupleTransformation _ | Offset _ | Multiplier _
62+
|OffsetMultiplier _ ->
63+
Common.ICE.internal_compiler_error
64+
[%message
65+
"Unknown transformation for int" (t : Expr.Typed.t Transformation.t)]
66+
in
5967
let low = if low = 0 && up <> 1 then low + 1 else low in
6068
Random.int (up - low + 1) + low
6169

@@ -66,7 +74,15 @@ let gen_num_real m t =
6674
| Transformation.Lower e -> (unwrap_num_exn m e, unwrap_num_exn m e +. diff)
6775
| Upper e -> (unwrap_num_exn m e -. diff, unwrap_num_exn m e)
6876
| LowerUpper (e1, e2) -> (unwrap_num_exn m e1, unwrap_num_exn m e2)
69-
| _ -> (def_low, def_low +. diff) in
77+
| Identity | Offset _ | Multiplier _ | OffsetMultiplier _ ->
78+
(def_low, def_low +. diff)
79+
| Ordered | PositiveOrdered | Simplex | UnitVector | SumToZero
80+
|CholeskyCorr | CholeskyCov | Correlation | Covariance | StochasticRow
81+
|StochasticColumn | TupleTransformation _ ->
82+
Common.ICE.internal_compiler_error
83+
[%message
84+
"Unknown transformation for real"
85+
(t : Expr.Typed.t Transformation.t)] in
7086
Random.float_range low up
7187

7288
let rec repeat n e =
@@ -75,49 +91,53 @@ let rec repeat n e =
7591
let rec repeat_th n f =
7692
match n with n when n <= 0 -> [] | m -> f () :: repeat_th (m - 1) f
7793

78-
let gen_bounded m gen e =
79-
match Expr.Helpers.try_unpack (eval_expr m e) with
80-
| Some unpacked_e -> List.map ~f:gen unpacked_e
81-
| None ->
82-
reject e.meta.loc
83-
(Fmt.str "Cannot evaluate bounded (upper OR lower) expr: %a"
84-
Expr.Typed.pp e)
94+
let gen_bounded m n gen e =
95+
let evaled = eval_expr m e in
96+
let exprs =
97+
match Expr.Helpers.try_unpack evaled with
98+
| Some unpacked -> unpacked
99+
| None -> List.init n ~f:(Fn.const evaled) in
100+
List.map ~f:gen exprs
85101

86-
let gen_ul_bounded m gen e1 e2 =
102+
let gen_ul_bounded m n gen e1 e2 =
87103
let create_bounds l u =
88104
List.map2_exn ~f:(fun x y -> gen (Transformation.LowerUpper (x, y))) l u
89105
in
90-
let e1, e2 = (eval_expr m e1, eval_expr m e2) in
91-
match Expr.Helpers.(try_unpack e1, try_unpack e2) with
92-
| Some unpacked_e1, Some unpacked_e2 -> create_bounds unpacked_e1 unpacked_e2
93-
| None, Some unpacked_e2 ->
94-
create_bounds
95-
(List.init (List.length unpacked_e2) ~f:(fun _ -> e1))
96-
unpacked_e2
97-
| Some unpacked_e1, None ->
98-
create_bounds unpacked_e1
99-
(List.init (List.length unpacked_e1) ~f:(fun _ -> e2))
100-
| None, None ->
101-
reject e1.meta.loc
102-
(Fmt.str "Cannot evaluate upper and lower bound expr: %a and %a"
103-
Expr.Typed.pp e1 Expr.Typed.pp e2)
106+
let e1 = eval_expr m e1 in
107+
let es1 =
108+
match Expr.Helpers.try_unpack e1 with
109+
| Some unpacked -> unpacked
110+
| None -> List.init n ~f:(Fn.const e1) in
111+
let e2 = eval_expr m e2 in
112+
let es2 =
113+
match Expr.Helpers.try_unpack e2 with
114+
| Some unpacked -> unpacked
115+
| None -> List.init n ~f:(Fn.const e2) in
116+
create_bounds es1 es2
104117

105118
let gen_row_vector m n t =
106119
match (t : Expr.Typed.t Transformation.t) with
107-
| Transformation.Lower ({meta= {type_= UVector | URowVector; _}; _} as e) ->
108-
gen_bounded m (fun x -> gen_num_real m (Transformation.Lower x)) e
120+
| Transformation.Lower
121+
({meta= {type_= UVector | URowVector | UReal | UInt; _}; _} as e) ->
122+
gen_bounded m n (fun x -> gen_num_real m (Lower x)) e
109123
|> Expr.Helpers.row_vector
110-
| Transformation.Upper ({meta= {type_= UVector | URowVector; _}; _} as e) ->
111-
gen_bounded m (fun x -> gen_num_real m (Transformation.Upper x)) e
124+
| Upper ({meta= {type_= UVector | URowVector | UReal | UInt; _}; _} as e) ->
125+
gen_bounded m n (fun x -> gen_num_real m (Upper x)) e
112126
|> Expr.Helpers.row_vector
113-
| Transformation.LowerUpper
127+
| LowerUpper
114128
( ({meta= {type_= UVector | URowVector | UReal | UInt; _}; _} as e1)
115-
, ({meta= {type_= UVector | URowVector; _}; _} as e2) )
116-
|Transformation.LowerUpper
117-
( ({meta= {type_= UVector | URowVector; _}; _} as e1)
118-
, ({meta= {type_= UReal | UInt; _}; _} as e2) ) ->
119-
gen_ul_bounded m (gen_num_real m) e1 e2 |> Expr.Helpers.row_vector
120-
| _ -> Expr.Helpers.row_vector (repeat_th n (fun _ -> gen_num_real m t))
129+
, ({meta= {type_= UVector | URowVector | UReal | UInt; _}; _} as e2) ) ->
130+
gen_ul_bounded m n (gen_num_real m) e1 e2 |> Expr.Helpers.row_vector
131+
| Identity | Offset _ | Multiplier _ | OffsetMultiplier _ ->
132+
Expr.Helpers.row_vector (repeat_th n (fun _ -> gen_num_real m t))
133+
| Ordered | PositiveOrdered | Simplex | UnitVector | SumToZero
134+
|CholeskyCorr | CholeskyCov | Correlation | Covariance | StochasticRow
135+
|StochasticColumn | TupleTransformation _ | Lower _ | LowerUpper _ | Upper _
136+
->
137+
Common.ICE.internal_compiler_error
138+
[%message
139+
"Unknown transformation for (row) vector"
140+
(t : Expr.Typed.t Transformation.t)]
121141

122142
let gen_vector m n t =
123143
let gen_ordered n =
@@ -146,9 +166,15 @@ let gen_vector m n t =
146166
in
147167
let l = List.map l ~f:(fun x -> x /. sum) in
148168
Expr.Helpers.vector l
149-
| _ ->
150-
let v = Expr.Helpers.unary_op Transpose (gen_row_vector m n t) in
151-
{v with meta= {v.meta with type_= UVector}}
169+
| SumToZero ->
170+
let l = repeat_th n (fun _ -> Random.float 1.) in
171+
let sum = List.fold l ~init:0. ~f:( +. ) in
172+
let l = List.map l ~f:(fun x -> x -. (sum /. float_of_int n)) in
173+
Expr.Helpers.vector l
174+
| Identity | Offset _ | Multiplier _ | OffsetMultiplier _ | CholeskyCorr
175+
|CholeskyCov | Correlation | Covariance | StochasticRow | StochasticColumn
176+
|TupleTransformation _ | Lower _ | LowerUpper _ | Upper _ ->
177+
Expr.Helpers.transpose (gen_row_vector m n t)
152178

153179
let gen_cov_unwrapped n =
154180
let l = repeat_th (n * n) (fun _ -> Random.float 2.) in
@@ -203,6 +229,20 @@ let gen_corr_matrix n =
203229
let corr_chol = gen_corr_cholesky_unwrapped n in
204230
wrap_real_mat (matprod corr_chol (transpose corr_chol))
205231

232+
let gen_sum_to_zero_matrix m n =
233+
(* to make each row and column sum to zero: - make each row sum to zero - add
234+
a new column which is - sum(rest of row) *)
235+
let rows =
236+
repeat_th (m - 1) (fun _ ->
237+
let row = repeat_th (n - 1) (fun _ -> Random.float 2.) in
238+
let row_sum = List.fold row ~init:0. ~f:( +. ) in
239+
row @ [-.row_sum]) in
240+
let col_sums =
241+
List.fold rows ~init:(repeat n 0.) ~f:(fun accum row ->
242+
List.map2_exn accum row ~f:( +. )) in
243+
let last_row = List.map col_sums ~f:(fun x -> -.x) in
244+
wrap_real_mat (rows @ [last_row])
245+
206246
let gen_matrix mm m n t =
207247
match (t : Expr.Typed.t Transformation.t) with
208248
| Covariance -> gen_cov_matrix m
@@ -211,17 +251,33 @@ let gen_matrix mm m n t =
211251
| CholeskyCorr -> gen_corr_cholesky m
212252
| Lower ({meta= {type_= UMatrix; _}; _} as e) ->
213253
Expr.Helpers.matrix_from_rows
214-
(gen_bounded mm (fun x -> gen_row_vector mm n (Lower x)) e)
254+
(gen_bounded mm m (fun x -> gen_row_vector mm n (Lower x)) e)
215255
| Upper ({meta= {type_= UMatrix; _}; _} as e) ->
216256
Expr.Helpers.matrix_from_rows
217-
(gen_bounded mm (fun x -> gen_row_vector mm n (Upper x)) e)
257+
(gen_bounded mm m (fun x -> gen_row_vector mm n (Upper x)) e)
218258
| LowerUpper (({meta= {type_= UMatrix; _}; _} as e1), e2)
219259
|LowerUpper (e1, ({meta= {type_= UMatrix; _}; _} as e2)) ->
220260
Expr.Helpers.matrix_from_rows
221-
(gen_ul_bounded mm (gen_row_vector mm n) e1 e2)
222-
| _ ->
261+
(gen_ul_bounded mm m (gen_row_vector mm n) e1 e2)
262+
| StochasticRow ->
263+
Expr.Helpers.matrix_from_rows
264+
(repeat_th m (fun () ->
265+
Expr.Helpers.transpose (gen_vector mm n Simplex)))
266+
| StochasticColumn ->
267+
Expr.Helpers.transpose
268+
(Expr.Helpers.matrix_from_rows
269+
(repeat_th m (fun () ->
270+
Expr.Helpers.transpose (gen_vector mm n Simplex))))
271+
| SumToZero -> gen_sum_to_zero_matrix m n
272+
| Identity | Lower _ | Upper _ | LowerUpper _ | Offset _ | Multiplier _
273+
|OffsetMultiplier _ ->
223274
Expr.Helpers.matrix_from_rows
224275
(repeat_th m (fun () -> gen_row_vector mm n t))
276+
| Ordered | PositiveOrdered | Simplex | UnitVector | TupleTransformation _ ->
277+
Common.ICE.internal_compiler_error
278+
[%message
279+
"Unknown transformation for matrix"
280+
(t : Expr.Typed.t Transformation.t)]
225281

226282
let gen_complex_unwrapped () =
227283
( gen_num_real Map.Poly.empty Transformation.Identity
@@ -244,13 +300,13 @@ let rec gen_array m st n t =
244300
match (t : Expr.Typed.t Transformation.t) with
245301
| Lower ({meta= {type_= UArray _; _}; _} as e) ->
246302
Expr.Helpers.array_expr
247-
(gen_bounded m (fun x -> generate_value m st (Lower x)) e)
303+
(gen_bounded m n (fun x -> generate_value m st (Lower x)) e)
248304
| Upper ({meta= {type_= UArray _; _}; _} as e) ->
249305
Expr.Helpers.array_expr
250-
(gen_bounded m (fun x -> generate_value m st (Upper x)) e)
306+
(gen_bounded m n (fun x -> generate_value m st (Upper x)) e)
251307
| LowerUpper (({meta= {type_= UArray _; _}; _} as e1), e2)
252308
|LowerUpper (e1, ({meta= {type_= UArray _; _}; _} as e2)) ->
253-
Expr.Helpers.array_expr (gen_ul_bounded m (generate_value m st) e1 e2)
309+
Expr.Helpers.array_expr (gen_ul_bounded m n (generate_value m st) e1 e2)
254310
| _ -> Expr.Helpers.array_expr (repeat_th n elt)
255311

256312
and gen_tuple m st t =

src/middle/Expr.ml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,20 @@ module Helpers = struct
186186
Some l
187187
| _ -> None
188188

189+
let transpose e =
190+
let new_type =
191+
match Typed.type_of e with
192+
| UnsizedType.URowVector -> UnsizedType.UVector
193+
| UVector -> URowVector
194+
| UComplexRowVector -> UComplexVector
195+
| UComplexVector -> UComplexRowVector
196+
| (UMatrix | UComplexMatrix) as t -> t
197+
| t ->
198+
Common.ICE.internal_compiler_error
199+
[%message "Cannot transpose " (t : UnsizedType.t)] in
200+
let expr = unary_op Transpose e in
201+
{expr with meta= {expr.meta with type_= new_type}}
202+
189203
let loop_bottom = one
190204

191205
let internal_funapp fn args meta =

src/middle/Expr.mli

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ module Helpers : sig
7272
val complex_row_vector : (float * float) list -> Typed.t
7373
val complex_vector : (float * float) list -> Typed.t
7474
val complex_matrix_from_rows : Typed.t list -> Typed.t
75+
val transpose : Typed.t -> Typed.t
7576
val array_expr : Typed.t list -> Typed.t
7677
val tuple_expr : Typed.t list -> Typed.t
7778
val try_unpack : Typed.t -> Typed.t list option

0 commit comments

Comments
 (0)