Skip to content

Commit 5639bf6

Browse files
authored
Merge pull request #191 from rstudio/wb-update
2 parents be4f6b6 + d5dce53 commit 5639bf6

3 files changed

Lines changed: 57 additions & 13 deletions

File tree

vetiver/server.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,20 @@
1-
from typing import Callable, List, Union
2-
from urllib.parse import urljoin
3-
41
import re
52
import httpx
63
import json
7-
import pandas as pd
84
import requests
95
import uvicorn
6+
import logging
7+
import pandas as pd
108
from fastapi import FastAPI, Request, testclient
119
from fastapi.exceptions import RequestValidationError
1210
from fastapi.openapi.utils import get_openapi
13-
from fastapi.responses import HTMLResponse, RedirectResponse
14-
from fastapi.responses import PlainTextResponse
11+
from fastapi.responses import HTMLResponse, RedirectResponse, PlainTextResponse
1512
from textwrap import dedent
1613
from warnings import warn
14+
from urllib.parse import urljoin
15+
from typing import Callable, List, Union
1716

18-
from .utils import _jupyter_nb
17+
from .utils import _jupyter_nb, get_workbench_path
1918
from .vetiver_model import VetiverModel
2019
from .meta import VetiverMeta
2120
from .helpers import api_data_to_frame, response_to_frame
@@ -71,6 +70,7 @@ def __init__(
7170
self.model = model
7271
self.app_factory = app_factory
7372
self.app = app_factory()
73+
self.workbench_path = None
7474

7575
if "check_ptype" in kwargs:
7676
check_prototype = kwargs.pop("check_ptype")
@@ -93,6 +93,14 @@ def _init_app(self):
9393
app = self.app
9494
app.openapi = self._custom_openapi
9595

96+
@app.on_event("startup")
97+
async def startup_event():
98+
logger = logging.getLogger("uvicorn.error")
99+
if self.workbench_path:
100+
logger.info(f"VetiverAPI starting at {self.workbench_path}")
101+
else:
102+
logger.info("VetiverAPI starting...")
103+
96104
@app.get("/", include_in_schema=False)
97105
def docs_redirect():
98106

@@ -261,7 +269,14 @@ def run(self, port: int = 8000, host: str = "127.0.0.1", **kw):
261269
>>> v_api.run() # doctest: +SKIP
262270
"""
263271
_jupyter_nb()
264-
uvicorn.run(self.app, port=port, host=host, **kw)
272+
self.workbench_path = get_workbench_path(port)
273+
274+
if self.workbench_path:
275+
uvicorn.run(
276+
self.app, port=port, host=host, root_path=self.workbench_path, **kw
277+
)
278+
else:
279+
uvicorn.run(self.app, port=port, host=host, **kw)
265280

266281
def _custom_openapi(self):
267282
import vetiver

vetiver/tests/test_xgboost.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,17 @@
55
from vetiver.data import mtcars # noqa
66
from vetiver.handlers.xgboost import XGBoostHandler # noqa
77
import numpy as np # noqa
8+
import sys # noqa
89
from fastapi.testclient import TestClient # noqa
910

1011
import vetiver # noqa
1112

13+
# hack since xgboost 2.0 dropped 3.7 support
14+
if sys.version_info[0] == 3 and sys.version_info[1] < 8:
15+
PREDICT_VALUE = 21.064373016357422
16+
else:
17+
PREDICT_VALUE = 19.963224411010742
18+
1219

1320
@pytest.fixture
1421
def xgb_model():
@@ -57,7 +64,7 @@ def test_vetiver_build(vetiver_client):
5764

5865
response = vetiver.predict(endpoint=vetiver_client, data=data)
5966

60-
assert response.iloc[0, 0] == 21.064373016357422
67+
assert response.iloc[0, 0] == PREDICT_VALUE
6168
assert len(response) == 1
6269

6370

@@ -66,7 +73,7 @@ def test_batch(vetiver_client):
6673

6774
response = vetiver.predict(endpoint=vetiver_client, data=data)
6875

69-
assert response.iloc[0, 0] == 21.064373016357422
76+
assert response.iloc[0, 0] == PREDICT_VALUE
7077
assert len(response) == 3
7178

7279

@@ -75,7 +82,7 @@ def test_no_ptype(vetiver_client_check_ptype_false):
7582

7683
response = vetiver.predict(endpoint=vetiver_client_check_ptype_false, data=data)
7784

78-
assert response.iloc[0, 0] == 21.064373016357422
85+
assert response.iloc[0, 0] == PREDICT_VALUE
7986
assert len(response) == 1
8087

8188

vetiver/utils.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import nest_asyncio
22
import warnings
33
import sys
4+
import os
5+
import subprocess
46
from types import SimpleNamespace
57

68
no_notebook = False
@@ -14,8 +16,8 @@ def _jupyter_nb():
1416

1517
if not no_notebook:
1618
warnings.warn(
17-
"WARNING: Jupyter Notebooks are not considered stable environments "
18-
"for production code"
19+
"You may be running from a notebook environment. Jupyter Notebooks are "
20+
"not considered stable environments for production code"
1921
)
2022
nest_asyncio.apply()
2123
else:
@@ -31,3 +33,23 @@ def inform(log, msg):
3133

3234
if not modelcard_options.quiet:
3335
print(msg, file=sys.stderr)
36+
37+
38+
def get_workbench_path(port):
39+
# check to see if in Posit Workbench, pulled from FastAPI section of user guide
40+
# https://docs.posit.co/ide/server-pro/user/vs-code/guide/proxying-web-servers.html#running-fastapi-with-uvicorn # noqa
41+
42+
if "RS_SERVER_URL" in os.environ and os.environ["RS_SERVER_URL"]:
43+
path = (
44+
subprocess.run(
45+
f"echo $(/usr/lib/rstudio-server/bin/rserver-url -l {port})",
46+
stdout=subprocess.PIPE,
47+
shell=True,
48+
)
49+
.stdout.decode()
50+
.strip()
51+
)
52+
# subprocess is run, new URL given
53+
return path
54+
else:
55+
return None

0 commit comments

Comments
 (0)