Skip to content

Commit 6c9c59a

Browse files
committed
plotting to plotly
1 parent 7bb9e5c commit 6c9c59a

3 files changed

Lines changed: 123 additions & 119 deletions

File tree

setup.cfg

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ install_requires =
3434
requests
3535
pins
3636
rsconnect-python
37-
matplotlib
38-
seaborn
37+
plotly
3938

4039
[options.extras_require]
4140
dev =

vetiver/__init__.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
__author__ = "Isabel Zimmerman <isabel.zimmerman@rstudio.com>"
55
__all__ = []
66

7-
<<<<<<< HEAD
87
import importlib # noqa
98
from .ptype import * # noqa
109
from .vetiver_model import VetiverModel # noqa
@@ -21,19 +20,3 @@
2120
from .handlers.torch import TorchHandler # noqa
2221
from .rsconnect import deploy_rsconnect # noqa
2322
from .monitor import compute_metrics, pin_metrics, plot_metrics # noqa
24-
=======
25-
import importlib
26-
from .ptype import *
27-
from .vetiver_model import *
28-
from .server import *
29-
from .mock import *
30-
from .pin_read_write import *
31-
from .attach_pkgs import *
32-
from .meta import *
33-
from .write_docker import *
34-
from .write_fastapi import *
35-
from .handlers._interface import *
36-
from .handlers.sklearn_vt import *
37-
from .handlers.pytorch_vt import *
38-
from .monitor import compute_metrics, plot_metrics, pin_metrics
39-
>>>>>>> f936189 (pin_metrics to facet)

vetiver/monitor.py

Lines changed: 122 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
import datetime
22
import pins
33
from pins.errors import PinsError
4-
import matplotlib.pyplot as plt
5-
import seaborn as sns
4+
import plotly.express as px
65
import pandas as pd
6+
from datetime import datetime, timedelta
77

88

9-
def compute_metrics(data, date_var, period, metric_set, truth, estimate):
9+
def compute_metrics(
10+
data: pd.DataFrame,
11+
date_var: str,
12+
period: timedelta,
13+
metric_set: list,
14+
truth: str,
15+
estimate: str,
16+
):
1017
"""
1118
Compute metrics for given time period
1219
@@ -23,121 +30,136 @@ def compute_metrics(data, date_var, period, metric_set, truth, estimate):
2330
List of metrics to compute
2431
https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics
2532
truth:
26-
Column identifier for true results
33+
Column name for true results
2734
estimate:
28-
Column identifier for predicted results
35+
Column name for predicted results
2936
30-
"""
31-
output = pd.DataFrame()
32-
data = data[[date_var, truth, estimate]]
33-
34-
dates_sorted = data.set_index(date_var).sort_index()
35-
rol = dates_sorted.set_index(truth).rolling(window=period)
36-
37-
for metric in metric_set:
38-
output[metric.__name__] = rol.apply(_compute,
39-
raw=False,
40-
kwargs={"metric": metric})
41-
42-
#BAD
43-
dates_sorted = dates_sorted.reset_index()
44-
output = pd.concat([dates_sorted[date_var],
45-
output.reset_index().drop(columns=truth)],
46-
axis = 1).set_index(date_var)
47-
48-
return output
49-
50-
51-
def _compute(window, metric):
37+
Example
38+
-------
39+
rng = pd.date_range("1/1/2012", periods=100, freq="S")
40+
new = dict(x=range(len(ts)), y = range(len(ts)))
41+
df = pd.DataFrame(new, index = rng).reset_index(inplace=True)
42+
metric_set = [metrics.mean_squared_error, metrics.mean_absolute_error]
43+
compute_metrics(df, "index", td, metric_set=metric_set, truth="x", estimate="y")
5244
53-
output = metric(y_true=window.index, y_pred=window)
54-
return output
55-
56-
57-
def pin_metrics(board, df_metrics, metrics_pin_name, overwrite=False):
5845
"""
59-
Update an existing pin storing model metrics over time
6046

61-
Parameters
62-
----------
63-
board :
64-
Pins board
65-
df_metrics: pd.DataFrame
66-
Dataframe of metrics over time, such as created by `vetiver_compute_metrics()`
67-
metrics_pin_name:
68-
Pin name for where the metrics are stored
69-
overwrite: bool
70-
If TRUE (the default), overwrite any metrics for
71-
dates that exist both in the existing pin and
72-
new metrics with the new values. If FALSE, error
73-
when the new metrics contain overlapping dates with
74-
the existing pin.
75-
"""
76-
date_types = (datetime.date, datetime.time, datetime.datetime)
77-
if not isinstance(df_metrics.index, date_types):
78-
try:
79-
df_metrics = df_metrics.index.astype("datetime")
80-
except TypeError:
81-
raise TypeError(f"Index of {df_metrics} must be a date type")
82-
83-
new_metrics = df_metrics.sort_index()
84-
85-
new_dates = df_metrics.index.unique()
47+
df = data[[truth, estimate, date_var]].set_index(date_var).sort_index()
48+
lst = [_ for _ in _rolling_df(df=df, td=period)]
8649

87-
try:
88-
old_metrics = board.pin_read(metrics_pin_name)
89-
except PinsError:
90-
board.pin_write(metrics_pin_name)
50+
rows = []
51+
for i in lst:
52+
for m in metric_set:
53+
rows = rows + [
54+
{
55+
"index": i.index[0],
56+
"n": len(i),
57+
"metric": m.__qualname__,
58+
"estimate": m(y_pred=i[truth], y_true=i[estimate]),
59+
}
60+
]
9161

92-
overlapping_dates = old_metrics.index in new_dates
62+
outdf = pd.DataFrame.from_dict(rows)
9363

94-
if overwrite is True:
95-
old_metrics = old_metrics not in overlapping_dates
96-
else:
97-
if overlapping_dates:
98-
raise ValueError(
99-
f"The new metrics overlap with dates \
100-
already stored in {repr(metrics_pin_name)} \
101-
Check the aggregated dates or use `overwrite = True`"
102-
)
64+
return outdf
10365

104-
new_metrics = old_metrics + df_metrics
105-
new_metrics = new_metrics.sort_index()
10666

107-
pins.pin_write(board, new_metrics, metrics_pin_name)
67+
def _rolling_df(df: pd.DataFrame, td: timedelta):
68+
first = df.index[0]
69+
last = df.index[-1]
10870

71+
while first < last:
72+
stop = first + td
73+
boolidx = (first <= df.index) & (df.index < stop)
74+
yield df[boolidx].copy()
75+
first = stop
10976

110-
def plot_metrics(df_metrics, metric, date_var = None):
11177

112-
ncols = 1
113-
nrows = len(metric)
114-
115-
fig, axes = plt.subplots(nrows=nrows, ncols=ncols)
78+
def pin_metrics(board, df_metrics, metrics_pin_name, overwrite=False):
79+
pass
80+
81+
82+
# """
83+
# Update an existing pin storing model metrics over time
84+
85+
# Parameters
86+
# ----------
87+
# board :
88+
# Pins board
89+
# df_metrics: pd.DataFrame
90+
# Dataframe of metrics over time, such as created by `vetiver_compute_metrics()`
91+
# metrics_pin_name:
92+
# Pin name for where the metrics are stored
93+
# overwrite: bool
94+
# If TRUE (the default), overwrite any metrics for
95+
# dates that exist both in the existing pin and
96+
# new metrics with the new values. If FALSE, error
97+
# when the new metrics contain overlapping dates with
98+
# the existing pin.
99+
# """
100+
# date_types = (datetime.date, datetime.time, datetime.datetime)
101+
# if not isinstance(df_metrics.index, date_types):
102+
# try:
103+
# df_metrics = df_metrics.index.astype("datetime")
104+
# except TypeError:
105+
# raise TypeError(f"Index of {df_metrics} must be a date type")
106+
107+
# new_metrics = df_metrics.sort_index()
108+
109+
# new_dates = df_metrics.index.unique()
110+
111+
# try:
112+
# old_metrics = board.pin_read(metrics_pin_name)
113+
# except PinsError:
114+
# board.pin_write(metrics_pin_name)
115+
116+
# overlapping_dates = old_metrics.index in new_dates
117+
118+
# if overwrite is True:
119+
# old_metrics = old_metrics not in overlapping_dates
120+
# else:
121+
# if overlapping_dates:
122+
# raise ValueError(
123+
# f"The new metrics overlap with dates \
124+
# already stored in {repr(metrics_pin_name)} \
125+
# Check the aggregated dates or use `overwrite = True`"
126+
# )
127+
128+
# new_metrics = old_metrics + df_metrics
129+
# new_metrics = new_metrics.sort_index()
130+
131+
# pins.pin_write(board, new_metrics, metrics_pin_name)
132+
133+
134+
def plot_metrics(
135+
df_metrics, date="index", estimate="estimate", metric="metric", n="n", **kw
136+
):
116137
"""
117138
Plot metrics over a given time period
118139
119140
Parameters
120141
----------
121142
df_metrics : DataFrame
122143
Pandas dataframe of metrics over time, such as created by `compute_metircs()`
123-
date_var:
124-
Column in `data` containing dates
125-
metric_set: list
126-
List of metrics to compute
127-
144+
date: str
145+
Column in `df_metrics` containing dates
146+
estimate: str
147+
Column in `df_metrics` containing metric output
148+
metric: str
149+
Column in `df_metrics` containing metric name
150+
n: str
151+
Column in `df_metrics` containing number of observations
128152
"""
129-
# date_types = (datetime.date, datetime.time, datetime.datetime)
130-
131-
# if not isinstance(df_metrics.index, date_types):
132-
# try:
133-
# df_metrics = df_metrics.index.astype("datetime")
134-
# except:
135-
# raise TypeError("need datetime")
136-
137-
# loop for plotting each column
138-
for i, col in enumerate(df_metrics.columns):
139-
sns.lineplot(
140-
x=df_metrics.index, y=df_metrics[col], ax=axes[i], color="royalblue"
141-
).set_title(col)
142-
143-
fig.tight_layout()
153+
154+
fig = px.line(
155+
df_metrics,
156+
x=date,
157+
y=estimate,
158+
color=metric,
159+
facet_row=metric,
160+
markers=n,
161+
**kw,
162+
)
163+
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
164+
fig.update_layout(showlegend=False)
165+
fig.show()

0 commit comments

Comments
 (0)