Skip to content

Commit 927e80d

Browse files
authored
Merge branch 'main' into handlers-update
2 parents 792fb4f + f37b86b commit 927e80d

12 files changed

Lines changed: 217 additions & 103 deletions

File tree

MANIFEST.in

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +0,0 @@
1-
include vetiver/data/*.csv

docs/source/conf.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import sys
3-
from vetiver import __version__
3+
import re
4+
from vetiver import __version__ as version
45

56
sys.path.insert(0, os.path.abspath("."))
67

@@ -12,8 +13,14 @@
1213
author = "Isabel Zimmerman"
1314

1415
# The full version, including alpha/beta/rc tags
15-
release = __version__
16-
16+
# Do not use the 0.0 version created by setuptools_scm when the package
17+
# is cloned too shallow to pickup
18+
p = re.compile(r"^0\.0\.post\d+\+g")
19+
if p.match(version):
20+
commit = p.sub("", version)
21+
version = f"Commit: {commit}"
22+
23+
release = version
1724

1825
# -- General configuration ---------------------------------------------------
1926

pyproject.toml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
[build-system]
2-
requires = ["setuptools>=45", "wheel"]
2+
requires = [
3+
"setuptools>=59",
4+
"setuptools_scm[toml]>=6.4",
5+
"wheel"
6+
]
37
build-backend = "setuptools.build_meta"
48

59
[tool.pytest.ini_options]
@@ -9,3 +13,7 @@ doctest_optionflags = "NORMALIZE_WHITESPACE"
913
markers = [
1014
"rsc_test: tests for rstudio connect",
1115
]
16+
17+
[tool.setuptools_scm]
18+
fallback_version = "999"
19+
version_scheme = 'post-release'

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ install_requires =
3535
pins
3636
rsconnect-python
3737
plotly
38+
importlib-metadata>=4.4 # NOTE: Remove when python_requires>=3.8
3839

3940
[options.extras_require]
4041
dev =

setup.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,4 @@
1-
import io
2-
import os
3-
import re
4-
5-
from setuptools import find_packages
61
from setuptools import setup
72

83

9-
def read(filename):
10-
filename = os.path.join(os.path.dirname(__file__), filename)
11-
text_type = type("")
12-
with io.open(filename, mode="r", encoding="utf-8") as fd:
13-
return re.sub(text_type(r":[a-z]+:`~?(.*?)`"), text_type(r"``\1``"), fd.read())
14-
15-
16-
if __name__ == "__main__":
17-
setup()
4+
setup()

vetiver/__init__.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
"""vetiver - Python parallel to R vetiver package"""
2+
# Change to import.metadata when minimum python>=3.8
3+
from importlib_metadata import version as _version
24

3-
__version__ = "0.1.5"
4-
__author__ = "Isabel Zimmerman <isabel.zimmerman@rstudio.com>"
5-
__all__ = []
6-
7-
import importlib # noqa
85
from .ptype import * # noqa
96
from .vetiver_model import VetiverModel # noqa
107
from .server import VetiverAPI, vetiver_endpoint, predict # noqa
@@ -17,5 +14,10 @@
1714
from .handlers.base import BaseHandler, create_handler, InvalidModelError # noqa
1815
from .handlers.sklearn import SKLearnHandler # noqa
1916
from .handlers.torch import TorchHandler # noqa
20-
from .rsconnect import deploy_rsconnect # noqa
21-
from .monitor import compute_metrics, pin_metrics, plot_metrics, _rolling_df # noqa
17+
from .rsconnect import deploy_rsconnect # noqa
18+
from .monitor import compute_metrics, pin_metrics, plot_metrics, _rolling_df # noqa
19+
20+
__author__ = "Isabel Zimmerman <isabel.zimmerman@rstudio.com>"
21+
__all__ = []
22+
__version__ = _version("vetiver")
23+
del _version

vetiver/monitor.py

Lines changed: 62 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
1-
import datetime
2-
import pins
3-
from pins.errors import PinsError
41
import plotly.express as px
52
import pandas as pd
6-
from datetime import datetime, timedelta
3+
from datetime import timedelta
74

85

96
def compute_metrics(
@@ -75,60 +72,74 @@ def _rolling_df(df: pd.DataFrame, td: timedelta):
7572
first = stop
7673

7774

78-
def pin_metrics(board, df_metrics, metrics_pin_name, overwrite=False):
79-
pass
75+
def pin_metrics(
76+
board,
77+
df_metrics: pd.DataFrame,
78+
metrics_pin_name: str,
79+
pin_type: "str | None" = None,
80+
index_name: str = "index",
81+
overwrite: bool = False,
82+
) -> pd.DataFrame:
83+
"""
84+
Update an existing pin storing model metrics over time
85+
86+
Parameters
87+
----------
88+
board :
89+
Pins board
90+
df_metrics: pd.DataFrame
91+
Dataframe of metrics over time, such as created by `vetiver_compute_metrics()`
92+
metrics_pin_name:
93+
Pin name for where the metrics are stored
94+
index_name:
95+
The column in df_metrics containing the aggregated dates or datetimes.
96+
Note that this defaults to a column named "index".
97+
overwrite: bool
98+
If TRUE (the default), overwrite any metrics for
99+
dates that exist both in the existing pin and
100+
new metrics with the new values. If FALSE, error
101+
when the new metrics contain overlapping dates with
102+
the existing pin.
103+
"""
80104

105+
old_metrics_raw = board.pin_read(metrics_pin_name)
81106

82-
# """
83-
# Update an existing pin storing model metrics over time
107+
# need to coerce date index to a datetime, since pandas does not infer
108+
# date columns from CSV (but note that formats like arrow do)
109+
old_metrics = old_metrics_raw.copy()
110+
old_metrics[index_name] = pd.to_datetime(old_metrics[index_name])
84111

85-
# Parameters
86-
# ----------
87-
# board :
88-
# Pins board
89-
# df_metrics: pd.DataFrame
90-
# Dataframe of metrics over time, such as created by `vetiver_compute_metrics()`
91-
# metrics_pin_name:
92-
# Pin name for where the metrics are stored
93-
# overwrite: bool
94-
# If TRUE (the default), overwrite any metrics for
95-
# dates that exist both in the existing pin and
96-
# new metrics with the new values. If FALSE, error
97-
# when the new metrics contain overlapping dates with
98-
# the existing pin.
99-
# """
100-
# date_types = (datetime.date, datetime.time, datetime.datetime)
101-
# if not isinstance(df_metrics.index, date_types):
102-
# try:
103-
# df_metrics = df_metrics.index.astype("datetime")
104-
# except TypeError:
105-
# raise TypeError(f"Index of {df_metrics} must be a date type")
112+
# handle overlapping dates ----
113+
dt_new = pd.to_datetime(df_metrics[index_name])
114+
dt_old = old_metrics[index_name]
106115

107-
# new_metrics = df_metrics.sort_index()
116+
indx_old_overlap = dt_old.isin(dt_new)
108117

109-
# new_dates = df_metrics.index.unique()
118+
if overwrite:
119+
# get only rows specific to old metrics, so when we concat below
120+
# it effectively is an upsert
121+
old_metrics = old_metrics.loc[~indx_old_overlap, :]
110122

111-
# try:
112-
# old_metrics = board.pin_read(metrics_pin_name)
113-
# except PinsError:
114-
# board.pin_write(metrics_pin_name)
123+
elif not overwrite and indx_old_overlap.any():
124+
raise ValueError(
125+
f"The new metrics overlap with dates already stored in {metrics_pin_name}."
126+
" Check the aggregated dates or use `overwrite=True`."
127+
)
115128

116-
# overlapping_dates = old_metrics.index in new_dates
129+
# update and pin ----
130+
combined_metrics = pd.concat([old_metrics, df_metrics], ignore_index=True)
131+
sorted_metrics = combined_metrics.sort_values(index_name)
117132

118-
# if overwrite is True:
119-
# old_metrics = old_metrics not in overlapping_dates
120-
# else:
121-
# if overlapping_dates:
122-
# raise ValueError(
123-
# f"The new metrics overlap with dates \
124-
# already stored in {repr(metrics_pin_name)} \
125-
# Check the aggregated dates or use `overwrite = True`"
126-
# )
133+
if pin_type is None:
134+
meta = board.pin_meta(metrics_pin_name)
127135

128-
# new_metrics = old_metrics + df_metrics
129-
# new_metrics = new_metrics.sort_index()
136+
final_pin_type = meta.type
137+
else:
138+
final_pin_type = pin_type
130139

131-
# pins.pin_write(board, new_metrics, metrics_pin_name)
140+
board.pin_write(sorted_metrics, metrics_pin_name, type=final_pin_type)
141+
142+
return sorted_metrics
132143

133144

134145
def plot_metrics(
@@ -157,9 +168,11 @@ def plot_metrics(
157168
y=estimate,
158169
color=metric,
159170
facet_row=metric,
160-
markers=n,
171+
markers=dict(size=n),
172+
hover_data={"n": ':'},
161173
**kw,
162174
)
175+
163176
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
164177
fig.update_layout(showlegend=False)
165178

vetiver/server.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,8 @@
66
import uvicorn
77
import requests
88
import pandas as pd
9-
import numpy as np
109
from typing import Callable, Union, List
1110

12-
from . import __version__
1311
from .vetiver_model import VetiverModel
1412
from .utils import _jupyter_nb
1513

@@ -65,7 +63,7 @@ async def ping():
6563

6664
if self.check_ptype is True:
6765

68-
@app.post("/predict/")
66+
@app.post("/predict")
6967
async def prediction(
7068
input_data: Union[self.model.ptype, List[self.model.ptype]]
7169
):
@@ -82,7 +80,7 @@ async def prediction(
8280

8381
elif self.check_ptype is False:
8482

85-
@app.post("/predict/")
83+
@app.post("/predict")
8684
async def prediction(input_data: Request):
8785

8886
y = await input_data.json()
@@ -99,17 +97,25 @@ async def rapidoc():
9997
<!doctype html>
10098
<html>
10199
<head>
102-
<meta name="viewport" content="width=device-width,minimum-scale=1,initial-scale=1,user-scalable=yes">
100+
<meta name="viewport"
101+
content="width=device-width,minimum-scale=1,initial-scale=1,user-scalable=yes">
103102
<title>RapiDoc</title>
104-
<script type="module" src="https://unpkg.com/rapidoc@9.1.3/dist/rapidoc-min.js"></script>
103+
<script type="module"
104+
src="https://unpkg.com/rapidoc@9.3.3/dist/rapidoc-min.js"></script>
105105
</script></head>
106106
<body>
107107
<rapi-doc spec-url="{self.app.openapi_url[1:]}"
108-
id="thedoc" render-style="read" schema-style="tree"
109-
show-components="true" show-info="true" show-header="true"
108+
id="thedoc"
109+
render-style="read"
110+
schema-style="tree"
111+
show-components="true"
112+
show-info="true"
113+
show-header="true"
110114
allow-search="true"
111115
show-side-nav="false"
112-
allow-authentication="false" update-route="false" match-type="regex"
116+
allow-authentication="false"
117+
update-route="false"
118+
match-type="regex"
113119
theme="light"
114120
header-color="#F2C6AC"
115121
primary-color = "#8C2D2D">
@@ -143,15 +149,15 @@ def vetiver_post(
143149
"""
144150
if self.check_ptype is True:
145151

146-
@self.app.post("/" + endpoint_name + "/")
152+
@self.app.post("/" + endpoint_name)
147153
async def custom_endpoint(input_data: self.model.ptype):
148154
y = _prepare_data(input_data)
149155
new = endpoint_fx(pd.Series(y))
150156
return {endpoint_name: new.tolist()}
151157

152158
else:
153159

154-
@self.app.post("/" + endpoint_name + "/")
160+
@self.app.post("/" + endpoint_name)
155161
async def custom_endpoint(input_data: Request):
156162
y = await input_data.json()
157163
new = endpoint_fx(pd.Series(y))
@@ -173,11 +179,13 @@ def run(self, port: int = 8000, host: str = "127.0.0.1", **kw):
173179
uvicorn.run(self.app, port=port, host=host, **kw)
174180

175181
def _custom_openapi(self):
182+
import vetiver
183+
176184
if self.app.openapi_schema:
177185
return self.app.openapi_schema
178186
openapi_schema = get_openapi(
179187
title=self.model.model_name + " model API",
180-
version=__version__,
188+
version=vetiver.__version__,
181189
description=self.model.description,
182190
routes=self.app.routes,
183191
servers=self.app.servers,
@@ -204,7 +212,7 @@ def predict(endpoint, data: Union[dict, pd.DataFrame, pd.Series], **kw):
204212
"""
205213
if isinstance(endpoint, testclient.TestClient):
206214
requester = endpoint
207-
endpoint = "/predict/"
215+
endpoint = "/predict"
208216
else:
209217
requester = requests
210218

vetiver/tests/test_add_endpoint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_endpoint_adds_ptype():
2828

2929
client = TestClient(app)
3030
data = {"B": 0, "C": 0, "D": 0}
31-
response = client.post("/sum/", json=data)
31+
response = client.post("/sum", json=data)
3232
assert response.status_code == 200, response.text
3333
assert response.json() == {"sum": 0}, response.json()
3434

@@ -38,6 +38,6 @@ def test_endpoint_adds_no_ptype():
3838

3939
client = TestClient(app)
4040
data = [0, 0, 0]
41-
response = client.post("/sum/", json=data)
41+
response = client.post("/sum", json=data)
4242
assert response.status_code == 200, response.text
4343
assert response.json() == {"sum": 0}, response.json()

0 commit comments

Comments
 (0)