Skip to content

Commit 85fef87

Browse files
authored
Merge pull request #76 from isabelizimm/dev-model-monitoring
model monitoring
2 parents 94ac2be + e9a9405 commit 85fef87

6 files changed

Lines changed: 218 additions & 1 deletion

File tree

docs/source/index.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,17 @@ Deploy
5353
~write_app
5454
~write_docker
5555

56+
Monitor
57+
==================
58+
59+
.. autosummary::
60+
:toctree: reference/
61+
:caption: Monitor
62+
63+
~compute_metrics
64+
~pin_metrics
65+
~plot_metrics
66+
5667
Advanced Usage
5768
==================
5869
.. toctree::

setup.cfg

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,13 @@ install_requires =
3434
requests
3535
pins
3636
rsconnect-python
37+
plotly
3738

3839
[options.extras_require]
3940
dev =
4041
pytest
4142
pytest-cov
43+
pytest-snapshot
4244
sphinx
4345
sphinx-autodoc-typehints
4446
sphinx-book-theme

vetiver/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,5 @@
1818
from .handlers.base import VetiverHandler # noqa
1919
from .handlers.sklearn import SKLearnHandler # noqa
2020
from .handlers.torch import TorchHandler # noqa
21-
from .rsconnect import deploy_rsconnect
21+
from .rsconnect import deploy_rsconnect # noqa
22+
from .monitor import compute_metrics, pin_metrics, plot_metrics, _rolling_df # noqa

vetiver/monitor.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
import datetime
2+
import pins
3+
from pins.errors import PinsError
4+
import plotly.express as px
5+
import pandas as pd
6+
from datetime import datetime, timedelta
7+
8+
9+
def compute_metrics(
10+
data: pd.DataFrame,
11+
date_var: str,
12+
period: timedelta,
13+
metric_set: list,
14+
truth: str,
15+
estimate: str,
16+
) -> pd.DataFrame:
17+
"""
18+
Compute metrics for given time period
19+
20+
Parameters
21+
----------
22+
data : DataFrame
23+
Pandas dataframe
24+
date_var:
25+
Column in `data` containing dates
26+
period: datetime.timedelta
27+
Defining period to group by
28+
metric_set: list
29+
List of metrics to compute, that have the parameters `y_true` and `y_pred`
30+
truth:
31+
Column name for true results
32+
estimate:
33+
Column name for predicted results
34+
35+
Example
36+
-------
37+
from sklearn import metrics
38+
rng = pd.date_range("1/1/2012", periods=10, freq="S")
39+
new = dict(x=range(len(rng)), y = range(len(rng)))
40+
df = pd.DataFrame(new, index = rng).reset_index(inplace=True)
41+
td = timedelta(seconds = 2)
42+
metric_set = [sklearn.metrics.mean_squared_error, sklearn.metrics.mean_absolute_error]
43+
compute_metrics(df, "index", td, metric_set=metric_set, truth="x", estimate="y")
44+
45+
"""
46+
47+
df = data[[truth, estimate, date_var]].set_index(date_var).sort_index()
48+
lst = [_ for _ in _rolling_df(df=df, td=period)]
49+
50+
rows = []
51+
for i in lst:
52+
for m in metric_set:
53+
rows = rows + [
54+
{
55+
"index": i.index[0],
56+
"n": len(i),
57+
"metric": m.__qualname__,
58+
"estimate": m(y_pred=i[truth], y_true=i[estimate]),
59+
}
60+
]
61+
62+
outdf = pd.DataFrame.from_dict(rows)
63+
64+
return outdf
65+
66+
67+
def _rolling_df(df: pd.DataFrame, td: timedelta):
68+
first = df.index[0]
69+
last = df.index[-1]
70+
71+
while first < last:
72+
stop = first + td
73+
boolidx = (first <= df.index) & (df.index < stop)
74+
yield df[boolidx].copy()
75+
first = stop
76+
77+
78+
def pin_metrics(board, df_metrics, metrics_pin_name, overwrite=False):
79+
pass
80+
81+
82+
# """
83+
# Update an existing pin storing model metrics over time
84+
85+
# Parameters
86+
# ----------
87+
# board :
88+
# Pins board
89+
# df_metrics: pd.DataFrame
90+
# Dataframe of metrics over time, such as created by `vetiver_compute_metrics()`
91+
# metrics_pin_name:
92+
# Pin name for where the metrics are stored
93+
# overwrite: bool
94+
# If TRUE (the default), overwrite any metrics for
95+
# dates that exist both in the existing pin and
96+
# new metrics with the new values. If FALSE, error
97+
# when the new metrics contain overlapping dates with
98+
# the existing pin.
99+
# """
100+
# date_types = (datetime.date, datetime.time, datetime.datetime)
101+
# if not isinstance(df_metrics.index, date_types):
102+
# try:
103+
# df_metrics = df_metrics.index.astype("datetime")
104+
# except TypeError:
105+
# raise TypeError(f"Index of {df_metrics} must be a date type")
106+
107+
# new_metrics = df_metrics.sort_index()
108+
109+
# new_dates = df_metrics.index.unique()
110+
111+
# try:
112+
# old_metrics = board.pin_read(metrics_pin_name)
113+
# except PinsError:
114+
# board.pin_write(metrics_pin_name)
115+
116+
# overlapping_dates = old_metrics.index in new_dates
117+
118+
# if overwrite is True:
119+
# old_metrics = old_metrics not in overlapping_dates
120+
# else:
121+
# if overlapping_dates:
122+
# raise ValueError(
123+
# f"The new metrics overlap with dates \
124+
# already stored in {repr(metrics_pin_name)} \
125+
# Check the aggregated dates or use `overwrite = True`"
126+
# )
127+
128+
# new_metrics = old_metrics + df_metrics
129+
# new_metrics = new_metrics.sort_index()
130+
131+
# pins.pin_write(board, new_metrics, metrics_pin_name)
132+
133+
134+
def plot_metrics(
135+
df_metrics, date="index", estimate="estimate", metric="metric", n="n", **kw
136+
) -> px.line:
137+
"""
138+
Plot metrics over a given time period
139+
140+
Parameters
141+
----------
142+
df_metrics : DataFrame
143+
Pandas dataframe of metrics over time, such as created by `compute_metircs()`
144+
date: str
145+
Column in `df_metrics` containing dates
146+
estimate: str
147+
Column in `df_metrics` containing metric output
148+
metric: str
149+
Column in `df_metrics` containing metric name
150+
n: str
151+
Column in `df_metrics` containing number of observations
152+
"""
153+
154+
fig = px.line(
155+
df_metrics,
156+
x=date,
157+
y=estimate,
158+
color=metric,
159+
facet_row=metric,
160+
markers=n,
161+
**kw,
162+
)
163+
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
164+
fig.update_layout(showlegend=False)
165+
166+
return fig
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"index":{"0":1325376000000,"1":1325376000000,"2":1325376002000,"3":1325376002000,"4":1325376004000,"5":1325376004000,"6":1325376006000,"7":1325376006000,"8":1325376008000,"9":1325376008000},"n":{"0":2,"1":2,"2":2,"3":2,"4":2,"5":2,"6":2,"7":2,"8":2,"9":2},"metric":{"0":"mean_squared_error","1":"mean_absolute_error","2":"mean_squared_error","3":"mean_absolute_error","4":"mean_squared_error","5":"mean_absolute_error","6":"mean_squared_error","7":"mean_absolute_error","8":"mean_squared_error","9":"mean_absolute_error"},"estimate":{"0":0.0,"1":0.0,"2":0.0,"3":0.0,"4":0.0,"5":0.0,"6":0.0,"7":0.0,"8":0.0,"9":0.0}}

vetiver/tests/test_monitor.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from sklearn import metrics
2+
from datetime import timedelta
3+
import pandas as pd
4+
import numpy
5+
import vetiver
6+
7+
rng = pd.date_range("1/1/2012", periods=10, freq="S")
8+
new = dict(x=range(len(rng)), y=range(len(rng)))
9+
df = pd.DataFrame(new, index=rng)
10+
td = timedelta(seconds=2)
11+
metric_set = [metrics.mean_squared_error, metrics.mean_absolute_error]
12+
13+
def test_rolling():
14+
m = [_ for _ in vetiver._rolling_df(df, td)]
15+
assert len(m) == 5
16+
assert len(m[0]) == 2
17+
18+
def test_compute():
19+
df.reset_index(inplace=True)
20+
m = vetiver.compute_metrics(
21+
df, "index", td, metric_set=metric_set, truth="x", estimate="y"
22+
)
23+
assert isinstance(m, pd.DataFrame)
24+
assert m.shape == (10, 4)
25+
numpy.testing.assert_array_equal(
26+
m.metric.unique(),
27+
numpy.array(["mean_squared_error", "mean_absolute_error"], dtype=object),
28+
)
29+
30+
def test_monitor(snapshot):
31+
snapshot.snapshot_dir = './vetiver/tests/snapshots'
32+
m = vetiver.compute_metrics(
33+
df, "index", td, metric_set=metric_set, truth="x", estimate="y"
34+
)
35+
vetiver.plot_metrics(m)
36+
snapshot.assert_match(m.to_json(), 'test_monitor.json')

0 commit comments

Comments
 (0)