Skip to content

Commit 8e91d13

Browse files
committed
snapshot tests for plotting
1 parent 7d7fbf4 commit 8e91d13

3 files changed

Lines changed: 17 additions & 3 deletions

File tree

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ install_requires =
4040
dev =
4141
pytest
4242
pytest-cov
43+
pytest-snapshot
4344
sphinx
4445
sphinx-autodoc-typehints
4546
sphinx-book-theme

vetiver/monitor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,4 +162,5 @@ def plot_metrics(
162162
)
163163
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
164164
fig.update_layout(showlegend=False)
165-
fig.show()
165+
166+
return fig

vetiver/tests/test_monitor.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,23 @@
1+
from typing import List
12
from sklearn import metrics
23
from datetime import timedelta
34
import pandas as pd
45
import numpy
5-
66
import vetiver
77

88
rng = pd.date_range("1/1/2012", periods=10, freq="S")
99
new = dict(x=range(len(rng)), y=range(len(rng)))
1010
df = pd.DataFrame(new, index=rng)
11-
df.reset_index(inplace=True)
1211
td = timedelta(seconds=2)
1312
metric_set = [metrics.mean_squared_error, metrics.mean_absolute_error]
1413

14+
def test_rolling():
15+
m = [_ for _ in vetiver._rolling_df(df, td)]
16+
assert len(m) == 5
17+
assert len(m[0]) == 2
1518

1619
def test_compute():
20+
df.reset_index(inplace=True)
1721
m = vetiver.compute_metrics(
1822
df, "index", td, metric_set=metric_set, truth="x", estimate="y"
1923
)
@@ -23,3 +27,11 @@ def test_compute():
2327
m.metric.unique(),
2428
numpy.array(["mean_squared_error", "mean_absolute_error"], dtype=object),
2529
)
30+
31+
def test_monitor(snapshot):
32+
snapshot.snapshot_dir = './snapshots'
33+
m = vetiver.compute_metrics(
34+
df, "index", td, metric_set=metric_set, truth="x", estimate="y"
35+
)
36+
vetiver.plot_metrics(m)
37+
snapshot.assert_match(m.to_json(), 'test_monitor.json')

0 commit comments

Comments
 (0)