Skip to content

Commit e9fe3a2

Browse files
committed
Debug data clean up
1 parent 0d0d9e4 commit e9fe3a2

4 files changed

Lines changed: 77 additions & 96 deletions

File tree

src/analysis_and_optimization/Debug_data_generation.ml

Lines changed: 50 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -85,71 +85,70 @@ let gen_num_real m t =
8585
(t : Expr.Typed.t Transformation.t)] in
8686
Random.float_range low up
8787

88-
let rec repeat n e =
89-
match n with n when n <= 0 -> [] | m -> e :: repeat (m - 1) e
88+
let repeat n e = List.init n ~f:(Fn.const e)
89+
let repeat_th n f = List.init n ~f:(fun _ -> f ())
90+
let random_floats n = repeat_th n (fun () -> Random.float 2.)
9091

91-
let rec repeat_th n f =
92-
match n with n when n <= 0 -> [] | m -> f () :: repeat_th (m - 1) f
92+
let unpack_or_repeat n (e : Expr.Typed.t) : Expr.Typed.t list =
93+
match e.pattern with
94+
| FunApp (CompilerInternal (FnMakeRowVec | FnMakeArray), l) -> l
95+
| FunApp
96+
( StanLib ("Transpose__", FnPlain, _)
97+
, [{pattern= FunApp (CompilerInternal FnMakeRowVec, l); _}] ) ->
98+
l
99+
| _ -> repeat n e
93100

94101
let gen_bounded m n gen e =
95-
let evaled = eval_expr m e in
96-
let exprs =
97-
match Expr.Helpers.try_unpack evaled with
98-
| Some unpacked -> unpacked
99-
| None -> List.init n ~f:(Fn.const evaled) in
102+
let exprs = unpack_or_repeat n (eval_expr m e) in
100103
List.map ~f:gen exprs
101104

102105
let gen_ul_bounded m n gen e1 e2 =
103106
let create_bounds l u =
104107
List.map2_exn ~f:(fun x y -> gen (Transformation.LowerUpper (x, y))) l u
105108
in
106-
let e1 = eval_expr m e1 in
107-
let es1 =
108-
match Expr.Helpers.try_unpack e1 with
109-
| Some unpacked -> unpacked
110-
| None -> List.init n ~f:(Fn.const e1) in
111-
let e2 = eval_expr m e2 in
112-
let es2 =
113-
match Expr.Helpers.try_unpack e2 with
114-
| Some unpacked -> unpacked
115-
| None -> List.init n ~f:(Fn.const e2) in
109+
let es1 = unpack_or_repeat n (eval_expr m e1) in
110+
let es2 = unpack_or_repeat n (eval_expr m e2) in
116111
create_bounds es1 es2
117112

118113
let gen_row_vector m n t =
119114
match (t : Expr.Typed.t Transformation.t) with
120-
| Transformation.Lower
121-
({meta= {type_= UVector | URowVector | UReal | UInt; _}; _} as e) ->
115+
| Transformation.Lower e ->
122116
gen_bounded m n (fun x -> gen_num_real m (Lower x)) e
123117
|> Expr.Helpers.row_vector
124-
| Upper ({meta= {type_= UVector | URowVector | UReal | UInt; _}; _} as e) ->
118+
| Upper e ->
125119
gen_bounded m n (fun x -> gen_num_real m (Upper x)) e
126120
|> Expr.Helpers.row_vector
127-
| LowerUpper
128-
( ({meta= {type_= UVector | URowVector | UReal | UInt; _}; _} as e1)
129-
, ({meta= {type_= UVector | URowVector | UReal | UInt; _}; _} as e2) ) ->
121+
| LowerUpper (e1, e2) ->
130122
gen_ul_bounded m n (gen_num_real m) e1 e2 |> Expr.Helpers.row_vector
131123
| Identity | Offset _ | Multiplier _ | OffsetMultiplier _ ->
132-
Expr.Helpers.row_vector (repeat_th n (fun _ -> gen_num_real m t))
124+
Expr.Helpers.row_vector (repeat_th n (fun () -> gen_num_real m t))
133125
| Ordered | PositiveOrdered | Simplex | UnitVector | SumToZero
134126
|CholeskyCorr | CholeskyCov | Correlation | Covariance | StochasticRow
135-
|StochasticColumn | TupleTransformation _ | Lower _ | LowerUpper _ | Upper _
136-
->
127+
|StochasticColumn | TupleTransformation _ ->
137128
Common.ICE.internal_compiler_error
138129
[%message
139-
"Unknown transformation for (row) vector"
130+
"Unknown transformation for row_vector"
140131
(t : Expr.Typed.t Transformation.t)]
141132

133+
let sum_to_zero_floats n =
134+
let l = random_floats n in
135+
let sum = List.fold l ~init:0. ~f:( +. ) in
136+
List.map l ~f:(fun x -> x -. (sum /. float_of_int n))
137+
138+
let simplex_floats n =
139+
let l = random_floats n in
140+
let sum = List.fold l ~init:0. ~f:( +. ) in
141+
List.map l ~f:(fun x -> x /. sum)
142+
142143
let gen_vector m n t =
143144
let gen_ordered n =
144-
let l = repeat_th n (fun _ -> Random.float 1.) in
145+
let l = random_floats n in
145146
List.fold_map l ~init:0. ~f:(fun accum elt ->
146147
let elt = accum +. elt in
147148
(elt, elt)) in
148149
match t with
149150
| Transformation.Simplex ->
150-
let l = repeat_th n (fun _ -> Random.float 1.) in
151-
let sum = List.fold l ~init:0. ~f:( +. ) in
152-
let l = List.map l ~f:(fun x -> x /. sum) in
151+
let l = simplex_floats n in
153152
Expr.Helpers.vector l
154153
| Ordered ->
155154
let max, l = gen_ordered n in
@@ -159,25 +158,28 @@ let gen_vector m n t =
159158
let _, l = gen_ordered n in
160159
Expr.Helpers.vector l
161160
| UnitVector ->
162-
let l = repeat_th n (fun _ -> Random.float 1.) in
161+
let l = random_floats n in
163162
let sum =
164163
Float.sqrt
165164
(List.fold l ~init:0. ~f:(fun accum elt -> accum +. (elt ** 2.)))
166165
in
167166
let l = List.map l ~f:(fun x -> x /. sum) in
168167
Expr.Helpers.vector l
169168
| SumToZero ->
170-
let l = repeat_th n (fun _ -> Random.float 1.) in
171-
let sum = List.fold l ~init:0. ~f:( +. ) in
172-
let l = List.map l ~f:(fun x -> x -. (sum /. float_of_int n)) in
169+
let l = sum_to_zero_floats n in
173170
Expr.Helpers.vector l
174-
| Identity | Offset _ | Multiplier _ | OffsetMultiplier _ | CholeskyCorr
175-
|CholeskyCov | Correlation | Covariance | StochasticRow | StochasticColumn
176-
|TupleTransformation _ | Lower _ | LowerUpper _ | Upper _ ->
171+
| Identity | Offset _ | Multiplier _ | OffsetMultiplier _ | Lower _
172+
|LowerUpper _ | Upper _ ->
177173
Expr.Helpers.transpose (gen_row_vector m n t)
174+
| CholeskyCorr | CholeskyCov | Correlation | Covariance | StochasticRow
175+
|StochasticColumn | TupleTransformation _ ->
176+
Common.ICE.internal_compiler_error
177+
[%message
178+
"Unknown transformation for vector"
179+
(t : Expr.Typed.t Transformation.t)]
178180

179181
let gen_cov_unwrapped n =
180-
let l = repeat_th (n * n) (fun _ -> Random.float 2.) in
182+
let l = random_floats (n * n) in
181183
let l_mat = vect_to_mat l n in
182184
matprod l_mat (transpose l_mat)
183185

@@ -195,22 +197,20 @@ let gen_diag_mat l =
195197
let fill_lower_triangular m =
196198
let fill_row i l =
197199
let _, tl = List.split_n l i in
198-
List.init ~f:(fun _ -> Random.float 2.) i @ tl in
200+
random_floats i @ tl in
199201
List.mapi ~f:fill_row m
200202

201203
let pad_mat mm m n =
202-
let padding_mat =
203-
List.init (m - n) ~f:(fun _ -> List.init n ~f:(fun _ -> Random.float 2.))
204-
in
204+
let padding_mat = repeat_th (m - n) (fun () -> random_floats n) in
205205
wrap_real_mat (mm @ padding_mat)
206206

207207
let gen_cov_cholesky m n =
208-
let diag_mat = gen_diag_mat (List.init ~f:(fun _ -> Random.float 2.) n) in
208+
let diag_mat = gen_diag_mat (random_floats n) in
209209
let filled_mat = fill_lower_triangular diag_mat in
210210
if m <= n then wrap_real_mat filled_mat else pad_mat filled_mat m n
211211

212212
let gen_corr_cholesky_unwrapped n =
213-
let diag_mat = gen_diag_mat (List.init ~f:(fun _ -> Random.float 2.) n) in
213+
let diag_mat = gen_diag_mat (random_floats n) in
214214
let filled_mat = fill_lower_triangular diag_mat in
215215
let row_normalizer l =
216216
let row_norm =
@@ -232,11 +232,7 @@ let gen_corr_matrix n =
232232
let gen_sum_to_zero_matrix m n =
233233
(* to make each row and column sum to zero: - make each row sum to zero - add
234234
a new column which is - sum(rest of row) *)
235-
let rows =
236-
repeat_th (m - 1) (fun _ ->
237-
let row = repeat_th (n - 1) (fun _ -> Random.float 2.) in
238-
let row_sum = List.fold row ~init:0. ~f:( +. ) in
239-
row @ [-.row_sum]) in
235+
let rows = repeat_th (m - 1) (fun () -> sum_to_zero_floats n) in
240236
let col_sums =
241237
List.fold rows ~init:(repeat n 0.) ~f:(fun accum row ->
242238
List.map2_exn accum row ~f:( +. )) in
@@ -260,14 +256,9 @@ let gen_matrix mm m n t =
260256
Expr.Helpers.matrix_from_rows
261257
(gen_ul_bounded mm m (gen_row_vector mm n) e1 e2)
262258
| StochasticRow ->
263-
Expr.Helpers.matrix_from_rows
264-
(repeat_th m (fun () ->
265-
Expr.Helpers.transpose (gen_vector mm n Simplex)))
259+
Expr.Helpers.matrix (repeat_th m (fun () -> simplex_floats n))
266260
| StochasticColumn ->
267-
Expr.Helpers.transpose
268-
(Expr.Helpers.matrix_from_rows
269-
(repeat_th m (fun () ->
270-
Expr.Helpers.transpose (gen_vector mm n Simplex))))
261+
Expr.Helpers.matrix (transpose (repeat_th n (fun () -> simplex_floats m)))
271262
| SumToZero -> gen_sum_to_zero_matrix m n
272263
| Identity | Lower _ | Upper _ | LowerUpper _ | Offset _ | Multiplier _
273264
|OffsetMultiplier _ ->

src/middle/Expr.ml

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -177,15 +177,6 @@ module Helpers = struct
177177
; adlevel= TupleAD (List.map ~f:Typed.adlevel_of l) }
178178
; pattern= FunApp (CompilerInternal FnMakeTuple, l) }
179179

180-
let try_unpack e =
181-
match e.pattern with
182-
| FunApp (CompilerInternal (FnMakeRowVec | FnMakeArray), l) -> Some l
183-
| FunApp
184-
( StanLib ("Transpose__", FnPlain, _)
185-
, [{pattern= FunApp (CompilerInternal FnMakeRowVec, l); _}] ) ->
186-
Some l
187-
| _ -> None
188-
189180
let transpose e =
190181
let new_type =
191182
match Typed.type_of e with

src/middle/Expr.mli

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ module Helpers : sig
7575
val transpose : Typed.t -> Typed.t
7676
val array_expr : Typed.t list -> Typed.t
7777
val tuple_expr : Typed.t list -> Typed.t
78-
val try_unpack : Typed.t -> Typed.t list option
7978
val loop_bottom : Typed.t
8079
val internal_funapp : 'a t Internal_fun.t -> 'a t list -> 'a -> 'a t
8180
val contains_fn_kind : ('a t Fun_kind.t -> bool) -> ?init:bool -> 'a t -> bool

test/unit/Debug_data_generation_tests.ml

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -213,10 +213,10 @@ let%expect_test "whole program data generation check" =
213213
0.090541396473787575, 0.043192465951095979, 0.12555826151802352,
214214
0.10919243895118017, 0.0040762414708886826, 0.015162211321058825
215215
],
216-
"j": [ -0.11142279798014265, 0.35444916424955808 ],
216+
"j": [ -0.2228455959602853, 0.70889832849911616 ],
217217
"k": [
218-
0.896716387037453, 1.7463321550725821, 1.8890769581764726,
219-
2.6141237355067388
218+
1.7934327740749061, 3.4926643101451642, 3.7781539163529452,
219+
5.2282474710134776
220220
],
221221
"l": [
222222
[ 0.38021657248665253, 0.0, 0.0, 0.0 ],
@@ -232,59 +232,59 @@ let%expect_test "whole program data generation check" =
232232
]
233233
],
234234
"zv": [
235-
0.131724853063783, 0.056365587188466293, -0.45727981825495007,
236-
0.030082007965824054, -0.081295883564592941, 0.23855136841195879,
237-
0.12102997636010093, 0.31252035187713811, -0.084656401249884272,
238-
-0.26704204179784391
235+
0.263449706127566, 0.11273117437693259, -0.91455963650990013,
236+
0.060164015931648107, -0.16259176712918588, 0.47710273682391757,
237+
0.24205995272020187, 0.62504070375427623, -0.16931280249976854,
238+
-0.53408408359568782
239239
],
240240
"zm": [
241241
[
242-
1.7903779233006683, 1.5847867503908164, 0.68134888421621276,
243-
-4.0565135579076976
242+
-0.09855574381348281, 0.51592882639360615, -0.087489119393748871,
243+
-0.32988396318637464
244244
],
245245
[
246-
1.1194692100409906, 0.88890763701559339, 1.0216031524558131,
247-
-3.0299799995123973
246+
0.49638223131349624, 0.29079105840364439, -0.61264680777095926,
247+
-0.17452648194618137
248248
],
249249
[
250-
0.32715908021625711, 0.35152264435373926, 1.5732645147953737,
251-
-2.25194623936537
250+
0.24160950850524265, 0.37430502394546239, -0.32013904829409362,
251+
-0.29577548415661148
252252
],
253253
[
254-
0.59535627722869866, 0.0662964423288532, 0.4481524391087563,
255-
-1.1098051586663082
254+
0.90249709642995324, -0.07541114113672176, -0.60447097603656719,
255+
-0.22261497925666413
256256
],
257257
[
258-
-3.8323624907866147, -2.8915134740890021, -3.7243689905761559,
259-
10.448244955451774
258+
-1.5419330924352093, -1.1056137676059912, 1.6247459514953688,
259+
1.0228009085458316
260260
]
261261
],
262262
"srm": [
263+
[
264+
0.14524735713946751, 0.2676928246643322, 0.2298449764268167,
265+
0.35721484176938356
266+
],
263267
[
264268
0.013902633661123317, 0.32771602782761389, 0.39502784967666338,
265269
0.26335348883459936
266270
],
267271
[
268272
0.26002298003803503, 0.18673815713314645, 0.33907506675093968,
269273
0.21416379607787894
270-
],
271-
[
272-
0.18154796808376553, 0.6083390995215554, 0.18923430780048708,
273-
0.020878624594191989
274274
]
275275
],
276276
"scm": [
277277
[
278-
0.34768699481768295, 0.0513471326409491, 0.32947167889815893,
279-
0.27149419364320893
278+
0.49112547258566897, 0.43806687475397765, 0.36384673991338035,
279+
0.25606541787153714
280280
],
281281
[
282-
0.23768237689080784, 0.19520318071229228, 0.27412016057501803,
283-
0.29299428182188186
282+
0.28461848662832567, 0.48962438588802881, 0.29982023824581161,
283+
0.35958785716771219
284284
],
285285
[
286-
0.14524735713946751, 0.2676928246643322, 0.2298449764268167,
287-
0.35721484176938356
286+
0.22425604078600533, 0.072308739357993479, 0.3363330218408081,
287+
0.38434672496075062
288288
]
289289
]
290290
} |}]

0 commit comments

Comments
 (0)