55import seaborn as sns
66import pandas as pd
77
8- def compute_metrics (data ,
9- date_var ,
10- period ,
11- metric_set ,
12- truth ,
13- estimate ):
8+
9+ def compute_metrics (data , date_var , period , metric_set , truth , estimate ):
1410 """
1511 Compute metrics for given time period
1612
1713 Parameters
1814 ----------
1915 data : DataFrame
2016 Pandas dataframe
21- date_var:
17+ date_var:
2218 Column in `data` containing dates
2319 period:
24- Defining period to group by
20+ Defining period to group by
2521 https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases
2622 metric_set: list
27- List of metrics to compute
23+ List of metrics to compute
2824 https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics
2925 truth:
3026 Column identifier for true results
@@ -33,57 +29,56 @@ def compute_metrics(data,
3329
3430 """
3531 output = pd .DataFrame ()
32+ data = data [[date_var , truth , estimate ]]
33+
3634 dates_sorted = data .set_index (date_var ).sort_index ()
3735 rol = dates_sorted .set_index (truth ).rolling (window = period )
3836
3937 for metric in metric_set :
40- rol [metric .__name__ ] = rol .apply (
41- _compute ,
42- raw = False ,
43- kwargs = {
44- 'metric' : metric }
45- )
38+ output [metric .__name__ ] = rol .apply (_compute ,
39+ raw = False ,
40+ kwargs = {"metric" : metric })
4641
47- rol .set_index (date_var )
42+ #BAD
43+ dates_sorted = dates_sorted .reset_index ()
44+ output = pd .concat ([dates_sorted [date_var ],
45+ output .reset_index ().drop (columns = truth )],
46+ axis = 1 ).set_index (date_var )
4847
4948 return output
5049
51- def _compute (window ,
52- metric ):
5350
54- output = metric (y_true = window .index ,
55- y_pred = window )
51+ def _compute (window , metric ):
52+
53+ output = metric (y_true = window .index , y_pred = window )
5654 return output
5755
5856
59- def pin_metrics (board ,
60- df_metrics ,
61- metrics_pin_name ,
62- overwrite = False ):
57+ def pin_metrics (board , df_metrics , metrics_pin_name , overwrite = False ):
6358 """
6459 Update an existing pin storing model metrics over time
6560
6661 Parameters
6762 ----------
68- board :
63+ board :
6964 Pins board
7065 df_metrics: pd.DataFrame
7166 Dataframe of metrics over time, such as created by `vetiver_compute_metrics()`
7267 metrics_pin_name:
7368 Pin name for where the metrics are stored
7469 overwrite: bool
75- If TRUE (the default), overwrite any metrics for
76- dates that exist both in the existing pin and
77- new metrics with the new values. If FALSE, error
78- when the new metrics contain overlapping dates with
70+ If TRUE (the default), overwrite any metrics for
71+ dates that exist both in the existing pin and
72+ new metrics with the new values. If FALSE, error
73+ when the new metrics contain overlapping dates with
7974 the existing pin.
8075 """
8176 date_types = (datetime .date , datetime .time , datetime .datetime )
8277 if not isinstance (df_metrics .index , date_types ):
83- try :
84- df_metrics = df_metrics .index .astype (' datetime' )
78+ try :
79+ df_metrics = df_metrics .index .astype (" datetime" )
8580 except TypeError :
86- raise TypeError ("Index must be a date type" )
81+ raise TypeError (f "Index of { df_metrics } must be a date type" )
8782
8883 new_metrics = df_metrics .sort_index ()
8984
@@ -100,7 +95,8 @@ def pin_metrics(board,
10095 old_metrics = old_metrics not in overlapping_dates
10196 else :
10297 if overlapping_dates :
103- raise ValueError (f"The new metrics overlap with dates \
98+ raise ValueError (
99+ f"The new metrics overlap with dates \
104100 already stored in { repr (metrics_pin_name )} \
105101 Check the aggregated dates or use `overwrite = True`"
106102 )
@@ -111,16 +107,12 @@ def pin_metrics(board,
111107 pins .pin_write (board , new_metrics , metrics_pin_name )
112108
113109
114- def plot_metrics (df_metrics ,
115- date_var , metric ):
110+ def plot_metrics (df_metrics , metric , date_var = None ):
116111
117112 ncols = 1
118113 nrows = len (metric )
119114
120- fig , axes = plt .subplots (
121- nrows = nrows ,
122- ncols = ncols
123- )
115+ fig , axes = plt .subplots (nrows = nrows , ncols = ncols )
124116 """
125117 Plot metrics over a given time period
126118
@@ -134,15 +126,18 @@ def plot_metrics(df_metrics,
134126 List of metrics to compute
135127
136128 """
137-
129+ # date_types = (datetime.date, datetime.time, datetime.datetime)
130+
131+ # if not isinstance(df_metrics.index, date_types):
132+ # try:
133+ # df_metrics = df_metrics.index.astype("datetime")
134+ # except:
135+ # raise TypeError("need datetime")
136+
138137 # loop for plotting each column
139138 for i , col in enumerate (df_metrics .columns ):
140139 sns .lineplot (
141- x = date_var ,
142- y = df_metrics [col ],
143- ax = axes [i ],
144- color = 'royalblue'
145- )\
146- .set_title (col )
147-
148- fig .tight_layout ()
140+ x = df_metrics .index , y = df_metrics [col ], ax = axes [i ], color = "royalblue"
141+ ).set_title (col )
142+
143+ fig .tight_layout ()
0 commit comments