11import pins
2- import matplotlib as plt
2+ import matplotlib .pyplot as plt
3+ import seaborn as sns
4+ import numpy as np
35import pandas as pd
46
5- # def vetiver_compute_metrics(data,
6- # date_var,
7- # .period,
8- # truth, estimate, **kw,
9- # metric_set = yardstick::metrics,
10- # .every = 1L,
11- # .origin = NULL,
12- # .before = 0L,
13- # .after = 0L,
14- # .complete = FALSE):
15-
16- # rlang::check_installed("slider")
17- # metrics_dots = list2(...)
18- # date_var = enquo(date_var)
19- # slider::slide_period_dfr(
20- # data,
21- # .i = data[[quo_name(date_var)]],
22- # .period = .period,
23- # .f = ~ tibble::tibble(
24- # !!date_var := min(.x[[quo_name(date_var)]]),
25- # n = nrow(.x),
26- # metric_set(.x, {{truth}}, {{estimate}}, !!!metrics_dots)
27- # ),
28- # .every = .every,
29- # .origin = .origin,
30- # .before = .before,
31- # .after = .after,
32- # .complete = .complete
33- # )
34-
35-
36- def vetiver_create_pin_metrics (board ,
37- df_metrics ,
38- metrics_pin_name ,
39- #.index = .index,
40- overwrite = True ):
def compute_metrics(data,
                    date_var,
                    period,
                    metric_set,
                    truth,
                    pred):
    """Compute each metric in ``metric_set`` over rolling windows of ``data``.

    Parameters
    ----------
    data : pandas.DataFrame
        Frame holding the date column, the truth column and the prediction
        column.
    date_var : str
        Name of the column used as the rolling index.
    period : int or offset
        Window size, passed unchanged to ``DataFrame.rolling(window=...)``.
    metric_set : iterable of callables
        Each callable is invoked as ``metric(y_true=..., y_pred=...)`` and is
        identified in the result by its ``__name__``.
    truth : str
        Name of the column holding the observed values.
    pred : str
        Name of the column holding the predicted values.

    Returns
    -------
    pandas.DataFrame
        One row per evaluated window, labelled by the largest index value in
        that window, and one column per metric.
    """
    metrics = list(metric_set)  # may be a one-shot iterable; reused per window
    indexed = data.set_index(date_var)

    # Iterating a Rolling object yields every window, including the short
    # leading ones. Mirror the default ``min_periods`` of rolling aggregations:
    # only full windows for an integer ``period``, at least one row otherwise.
    min_rows = period if isinstance(period, (int, np.integer)) else 1

    results = {}
    for window in indexed.rolling(window=period):
        if len(window) < min_rows:
            continue
        # Windows ending on the same label share (and overwrite) one row.
        row = results.setdefault(window.index.max(), {})
        for metric in metrics:
            row[metric.__name__] = _compute(window, None, truth, pred, metric)

    return pd.DataFrame(list(results.values()), index=list(results.keys()))


def _compute(window,
             df,
             truth,
             pred,
             metric):
    """Evaluate one metric on one rolling window.

    ``window`` is a DataFrame slice containing the ``truth`` and ``pred``
    columns. The metric is called with those two columns as ``y_true`` and
    ``y_pred``. When ``df`` is not None the value is also stored in ``df`` at
    row ``window.index.max()`` and column ``metric.__name__``. The metric
    value is returned.
    """
    output = metric(y_true=window[truth], y_pred=window[pred])
    if df is not None:
        df.loc[window.index.max(), metric.__name__] = output
    return output
40+
41+
42+ # def pin_metrics(board,
43+ # df_metrics,
44+ # metrics_pin_name,
45+ # #.index = .index,
46+ # overwrite = True):
47+
48+ # # date_var <- quo_name(enquo(date_var)) #enquo?
5949
60- new_metrics = old_metrics + df_metrics
61- new_metrics < - vec_slice (
62- new_metrics ,
63- vctrs ::vec_order (new_metrics .index )
64- )
65-
66- pins .pin_write (board , new_metrics , metrics_pin_name )
67-
68-
69-
70- def compute_metrics (data , date_var ,
71- metric_set ,
72- truth_quo ,
73- estimate_quo ,
74- * kw ):
75- index = data .date_var
76- index = min (index )
77-
78- n = len (data )
79-
80- metrics = metric_set (
81- data = data ,
82- truth = truth_quo ,
83- estimate = estimate_quo
84- )
85-
86- tibble ::tibble (
87- .index = index ,
88- .n = n ,
89- metrics
90- )
50+ # new_metrics = df_metrics.sort()
9151
52+ # new_dates = df_metrics.index.unique()
9253
93- # def eval_select_one(col, data, arg, *kw, call = caller_env()):
94- # rlang::check_installed("tidyselect")
95- # check_dots_empty()
54+ # old_metrics = pins.pin_read(board, metrics_pin_name)
55+ # overlapping_dates = old_metrics.index in new_dates
9656
97- # # `col` is a quosure that has its own environment attached
98- # env = empty_env()
57+ # if overwrite is True:
58+ # old_metrics = old_metrics not in overlapping_dates
59+ # else:
60+ # if overlapping_dates:
61+ # raise ValueError(f"The new metrics overlap with dates \
62+ # already stored in {repr(metrics_pin_name)} \
63+ # Check the aggregated dates or use `overwrite = True`"
64+ # )
9965
100- # loc = tidyselect::eval_select(
101- # expr = col,
102- # data = data,
103- # env = env,
104- # error_call = call
66+ # new_metrics = old_metrics + df_metrics
67+ # new_metrics <- vec_slice(
68+ # new_metrics,
69+ # vctrs::vec_order(new_metrics.index)
10570# )
10671
107- # if (length(loc) != 1):
108- # raise ValueError("`{arg}` must specify exactly one column from `data`.")
72+ # pins.pin_write(board, new_metrics, metrics_pin_name)
10973
110- # return loc
11174
def vetiver_plot_metrics(df_metrics,
                         date_var,
                         metric):
    """Draw one line plot per metric, stacked in a single column.

    Parameters
    ----------
    df_metrics : pandas.DataFrame
        Frame with one column per metric.
    date_var : str
        Name of the column used for the x axis. If no such column exists,
        the frame's index is used instead.
    metric : str or iterable of str
        Column name(s) in ``df_metrics`` to plot, one subplot each.

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding the subplots.

    Raises
    ------
    ValueError
        If ``metric`` is empty.
    """
    # A bare string is one metric, not a sequence of characters.
    metrics = [metric] if isinstance(metric, str) else list(metric)
    if not metrics:
        raise ValueError("`metric` must name at least one column to plot.")

    if date_var in df_metrics.columns:
        dates = df_metrics[date_var]
    else:
        dates = df_metrics.index

    # squeeze=False keeps `axes` two-dimensional even for a single subplot,
    # so the same indexing works for any number of metrics.
    fig, axes = plt.subplots(
        nrows=len(metrics),
        ncols=1,
        squeeze=False
    )

    # Plot only the requested metric columns, one per axis.
    for ax, col in zip(axes[:, 0], metrics):
        sns.lineplot(
            x=dates,
            y=df_metrics[col],
            ax=ax,
            color='royalblue'
        ).set_title(col)

    fig.tight_layout()
    return fig
0 commit comments