Skip to content

Commit 2451e9c

Browse files
committed
pin_metrics to facet
1 parent 46ebef9 commit 2451e9c

2 files changed

Lines changed: 117 additions & 49 deletions

File tree

vetiver/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
__author__ = "Isabel Zimmerman <isabel.zimmerman@rstudio.com>"
55
__all__ = []
66

7+
<<<<<<< HEAD
78
import importlib # noqa
89
from .ptype import * # noqa
910
from .vetiver_model import VetiverModel # noqa
@@ -20,3 +21,19 @@
2021
from .handlers.torch import TorchHandler # noqa
2122
from .rsconnect import deploy_rsconnect # noqa
2223
from .monitor import compute_metrics, pin_metrics, plot_metrics # noqa
24+
=======
25+
import importlib
26+
from .ptype import *
27+
from .vetiver_model import *
28+
from .server import *
29+
from .mock import *
30+
from .pin_read_write import *
31+
from .attach_pkgs import *
32+
from .meta import *
33+
from .write_docker import *
34+
from .write_fastapi import *
35+
from .handlers._interface import *
36+
from .handlers.sklearn_vt import *
37+
from .handlers.pytorch_vt import *
38+
from .monitor import compute_metrics, plot_metrics, pin_metrics
39+
>>>>>>> f936189 (pin_metrics to facet)

vetiver/monitor.py

Lines changed: 100 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,80 +1,118 @@
1+
import datetime
12
import pins
3+
from pins import PinsError
24
import matplotlib.pyplot as plt
35
import seaborn as sns
4-
import numpy as np
56
import pandas as pd
67

78
def compute_metrics(data,
89
date_var,
910
period,
1011
metric_set,
1112
truth,
12-
pred):
13-
13+
estimate):
14+
"""
15+
Compute metrics for given time period
16+
17+
Parameters
18+
----------
19+
data : DataFrame
20+
Pandas dataframe
21+
date_var:
22+
Column in `data` containing dates
23+
period:
24+
Defining period to group by
25+
https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases
26+
metric_set: list
27+
List of metrics to compute
28+
https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics
29+
truth:
30+
Column identifier for true results
31+
estimate:
32+
Column identifier for predicted results
33+
34+
"""
1435
output = pd.DataFrame()
15-
rol = data.set_index(date_var)
16-
rol = rol.rolling(window=period)
36+
dates_sorted = data.set_index(date_var).sort_index()
37+
rol = dates_sorted.set_index(truth).rolling(window=period)
1738

1839
for metric in metric_set:
19-
rol.apply(
40+
rol[metric.__name__] =rol.apply(
2041
_compute,
2142
raw=False,
22-
kwargs = {'df':output,
23-
'truth': truth,
24-
'pred': pred,
43+
kwargs = {
2544
'metric': metric}
2645
)
2746

47+
rol.set_index(date_var)
48+
2849
return output
2950

3051
def _compute(window,
31-
df,
32-
truth,
33-
pred,
3452
metric):
35-
output = metric(y_true=window,
36-
# TO DO: change y pred to correct piece
37-
y_pred=np.random.randint(0,1000,size=(len(window),1)))
38-
df.loc[window.index.max(), metric.__name__] = output
39-
return output
40-
41-
42-
# def pin_metrics(board,
43-
# df_metrics,
44-
# metrics_pin_name,
45-
# #.index = .index,
46-
# overwrite = True):
47-
48-
# # date_var <- quo_name(enquo(date_var)) #enquo?
4953

50-
# new_metrics = df_metrics.sort()
51-
52-
# new_dates = df_metrics.index.unique()
54+
output = metric(y_true = window.index,
55+
y_pred=window)
56+
return output
5357

54-
# old_metrics = pins.pin_read(board, metrics_pin_name)
55-
# overlapping_dates = old_metrics.index in new_dates
5658

57-
# if overwrite is True:
58-
# old_metrics = old_metrics not in overlapping_dates
59-
# else:
60-
# if overlapping_dates:
61-
# raise ValueError(f"The new metrics overlap with dates \
62-
# already stored in {repr(metrics_pin_name)} \
63-
# Check the aggregated dates or use `overwrite = True`"
64-
# )
59+
def pin_metrics(board,
60+
df_metrics,
61+
metrics_pin_name,
62+
overwrite = False):
63+
"""
64+
Update an existing pin storing model metrics over time
65+
66+
Parameters
67+
----------
68+
board :
69+
Pins board
70+
df_metrics: pd.DataFrame
71+
Dataframe of metrics over time, such as created by `vetiver_compute_metrics()`
72+
metrics_pin_name:
73+
Pin name for where the metrics are stored
74+
overwrite: bool
75+
If TRUE (the default), overwrite any metrics for
76+
dates that exist both in the existing pin and
77+
new metrics with the new values. If FALSE, error
78+
when the new metrics contain overlapping dates with
79+
the existing pin.
80+
"""
81+
date_types = (datetime.date, datetime.time, datetime.datetime)
82+
if not isinstance(df_metrics.index, date_types):
83+
try:
84+
df_metrics = df_metrics.index.astype('datetime')
85+
except TypeError:
86+
raise TypeError("Index must be a date type")
87+
88+
new_metrics = df_metrics.sort_index()
89+
90+
new_dates = df_metrics.index.unique()
91+
92+
try:
93+
old_metrics = board.pin_read(metrics_pin_name)
94+
except PinsError:
95+
board.pin_write(metrics_pin_name)
96+
97+
overlapping_dates = old_metrics.index in new_dates
98+
99+
if overwrite is True:
100+
old_metrics = old_metrics not in overlapping_dates
101+
else:
102+
if overlapping_dates:
103+
raise ValueError(f"The new metrics overlap with dates \
104+
already stored in {repr(metrics_pin_name)} \
105+
Check the aggregated dates or use `overwrite = True`"
106+
)
65107

66-
# new_metrics = old_metrics + df_metrics
67-
# new_metrics <- vec_slice(
68-
# new_metrics,
69-
# vctrs::vec_order(new_metrics.index)
70-
# )
108+
new_metrics = old_metrics + df_metrics
109+
new_metrics = new_metrics.sort_index()
71110

72-
# pins.pin_write(board, new_metrics, metrics_pin_name)
111+
pins.pin_write(board, new_metrics, metrics_pin_name)
73112

74113

75-
def vetiver_plot_metrics(df_metrics,
76-
date_var,
77-
metric):
114+
def plot_metrics(df_metrics,
115+
date_var, metric):
78116

79117
ncols = 1
80118
nrows = len(metric)
@@ -83,11 +121,24 @@ def vetiver_plot_metrics(df_metrics,
83121
nrows=nrows,
84122
ncols=ncols
85123
)
124+
"""
125+
Plot metrics over a given time period
126+
127+
Parameters
128+
----------
129+
df_metrics : DataFrame
130+
Pandas dataframe of metrics over time, such as created by `compute_metircs()`
131+
date_var:
132+
Column in `data` containing dates
133+
metric_set: list
134+
List of metrics to compute
135+
136+
"""
86137

87138
# loop for plotting each column
88139
for i, col in enumerate(df_metrics.columns):
89140
sns.lineplot(
90-
x=df_metrics[date_var],
141+
x=date_var,
91142
y=df_metrics[col],
92143
ax=axes[i],
93144
color='royalblue'

0 commit comments

Comments
 (0)