Skip to content

Commit 7bb9e5c

Browse files
committed
add dependencies
1 parent dc64334 commit 7bb9e5c

2 files changed

Lines changed: 45 additions & 48 deletions

File tree

setup.cfg

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ install_requires =
3434
requests
3535
pins
3636
rsconnect-python
37+
matplotlib
38+
seaborn
3739

3840
[options.extras_require]
3941
dev =

vetiver/monitor.py

Lines changed: 43 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,22 @@
55
import seaborn as sns
66
import pandas as pd
77

8-
def compute_metrics(data,
9-
date_var,
10-
period,
11-
metric_set,
12-
truth,
13-
estimate):
8+
9+
def compute_metrics(data, date_var, period, metric_set, truth, estimate):
1410
"""
1511
Compute metrics for given time period
1612
1713
Parameters
1814
----------
1915
data : DataFrame
2016
Pandas dataframe
21-
date_var:
17+
date_var:
2218
Column in `data` containing dates
2319
period:
24-
Defining period to group by
20+
Period to group by, given as a pandas offset alias:
2521
https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases
2622
metric_set: list
27-
List of metrics to compute
23+
List of metrics to compute
2824
https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics
2925
truth:
3026
Column identifier for true results
@@ -33,57 +29,56 @@ def compute_metrics(data,
3329
3430
"""
3531
output = pd.DataFrame()
32+
data = data[[date_var, truth, estimate]]
33+
3634
dates_sorted = data.set_index(date_var).sort_index()
3735
rol = dates_sorted.set_index(truth).rolling(window=period)
3836

3937
for metric in metric_set:
40-
rol[metric.__name__] =rol.apply(
41-
_compute,
42-
raw=False,
43-
kwargs = {
44-
'metric': metric}
45-
)
38+
output[metric.__name__] = rol.apply(_compute,
39+
raw=False,
40+
kwargs={"metric": metric})
4641

47-
rol.set_index(date_var)
42+
# BAD (FIXME): dates are re-attached to the rolling output by row position below
43+
dates_sorted = dates_sorted.reset_index()
44+
output = pd.concat([dates_sorted[date_var],
45+
output.reset_index().drop(columns=truth)],
46+
axis = 1).set_index(date_var)
4847

4948
return output
5049

51-
def _compute(window,
52-
metric):
5350

54-
output = metric(y_true = window.index,
55-
y_pred=window)
51+
def _compute(window, metric):
52+
53+
output = metric(y_true=window.index, y_pred=window)
5654
return output
5755

5856

59-
def pin_metrics(board,
60-
df_metrics,
61-
metrics_pin_name,
62-
overwrite = False):
57+
def pin_metrics(board, df_metrics, metrics_pin_name, overwrite=False):
6358
"""
6459
Update an existing pin storing model metrics over time
6560
6661
Parameters
6762
----------
68-
board :
63+
board :
6964
Pins board
7065
df_metrics: pd.DataFrame
7166
Dataframe of metrics over time, such as created by `compute_metrics()`
7267
metrics_pin_name:
7368
Pin name for where the metrics are stored
7469
overwrite: bool
75-
If TRUE (the default), overwrite any metrics for
76-
dates that exist both in the existing pin and
77-
new metrics with the new values. If FALSE, error
78-
when the new metrics contain overlapping dates with
70+
If True, overwrite any metrics for
71+
dates that exist both in the existing pin and
72+
new metrics with the new values. If False (the default), error
73+
when the new metrics contain overlapping dates with
7974
the existing pin.
8075
"""
8176
date_types = (datetime.date, datetime.time, datetime.datetime)
8277
if not isinstance(df_metrics.index, date_types):
83-
try:
84-
df_metrics = df_metrics.index.astype('datetime')
78+
try:
79+
df_metrics = df_metrics.index.astype("datetime")
8580
except TypeError:
86-
raise TypeError("Index must be a date type")
81+
raise TypeError(f"Index of {df_metrics} must be a date type")
8782

8883
new_metrics = df_metrics.sort_index()
8984

@@ -100,7 +95,8 @@ def pin_metrics(board,
10095
old_metrics = old_metrics not in overlapping_dates
10196
else:
10297
if overlapping_dates:
103-
raise ValueError(f"The new metrics overlap with dates \
98+
raise ValueError(
99+
f"The new metrics overlap with dates \
104100
already stored in {repr(metrics_pin_name)} \
105101
Check the aggregated dates or use `overwrite = True`"
106102
)
@@ -111,16 +107,12 @@ def pin_metrics(board,
111107
pins.pin_write(board, new_metrics, metrics_pin_name)
112108

113109

114-
def plot_metrics(df_metrics,
115-
date_var, metric):
110+
def plot_metrics(df_metrics, metric, date_var = None):
116111

117112
ncols = 1
118113
nrows = len(metric)
119114

120-
fig, axes = plt.subplots(
121-
nrows=nrows,
122-
ncols=ncols
123-
)
115+
fig, axes = plt.subplots(nrows=nrows, ncols=ncols)
124116
"""
125117
Plot metrics over a given time period
126118
@@ -134,15 +126,18 @@ def plot_metrics(df_metrics,
134126
List of metrics to compute
135127
136128
"""
137-
129+
# date_types = (datetime.date, datetime.time, datetime.datetime)
130+
131+
# if not isinstance(df_metrics.index, date_types):
132+
# try:
133+
# df_metrics = df_metrics.index.astype("datetime")
134+
# except:
135+
# raise TypeError("need datetime")
136+
138137
# loop for plotting each column
139138
for i, col in enumerate(df_metrics.columns):
140139
sns.lineplot(
141-
x=date_var,
142-
y=df_metrics[col],
143-
ax=axes[i],
144-
color='royalblue'
145-
)\
146-
.set_title(col)
147-
148-
fig.tight_layout()
140+
x=df_metrics.index, y=df_metrics[col], ax=axes[i], color="royalblue"
141+
).set_title(col)
142+
143+
fig.tight_layout()

0 commit comments

Comments
 (0)