Skip to content

Commit 1a96b34

Browse files
authored
Merge pull request #65 from isabelizimm/dev-new-names
renaming functions
2 parents 8c8e5ed + 281c391 commit 1a96b34

14 files changed

Lines changed: 184 additions & 140 deletions

vetiver/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66

77
import importlib
88
from .ptype import *
9-
from .vetiver_model import *
10-
from .server import *
11-
from .mock import *
12-
from .pin_read_write import *
9+
from .vetiver_model import VetiverModel
10+
from .server import VetiverAPI, vetiver_endpoint
11+
from .mock import get_mock_data, get_mock_model
12+
from .pin_read_write import vetiver_pin_write
1313
from .attach_pkgs import *
1414
from .meta import *
15-
from .write_docker import *
16-
from .write_fastapi import *
15+
from .write_docker import write_docker
16+
from .write_fastapi import write_app
1717
from .handlers._interface import create_handler, InvalidModelError
1818
from .handlers.base import VetiverHandler
1919
from .handlers.sklearn import SKLearnHandler

vetiver/handlers/base.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from ..ptype import vetiver_create_ptype
2-
from ..meta import vetiver_meta
31
from abc import ABCMeta
42

3+
from ..ptype import vetiver_create_ptype
4+
from ..meta import _model_meta
5+
56
class VetiverHandler(metaclass=ABCMeta):
67
"""Base handler class for creating VetiverModel of different type.
78
@@ -30,7 +31,7 @@ def create_meta(
3031
required_pkgs: list = [],
3132
):
3233
"""Create metadata for sklearn model"""
33-
meta = vetiver_meta(user, version, url, required_pkgs)
34+
meta = _model_meta(user, version, url, required_pkgs)
3435

3536
return meta
3637

vetiver/handlers/sklearn.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import pandas as pd
22
import sklearn
33

4+
from ..meta import _model_meta
45
from .base import VetiverHandler
5-
from ..meta import vetiver_meta
6+
67

78
class SKLearnHandler(VetiverHandler):
89
"""Handler class for creating VetiverModels with sklearn.
@@ -12,14 +13,14 @@ class SKLearnHandler(VetiverHandler):
1213
model : sklearn.base.BaseEstimator
1314
a trained sklearn model
1415
"""
16+
1517
base_class = sklearn.base.BaseEstimator
16-
18+
1719
def __init__(self, model, ptype_data):
1820
super().__init__(model, ptype_data)
1921

2022
def describe(self):
21-
"""Create description for sklearn model
22-
"""
23+
"""Create description for sklearn model"""
2324
desc = f"Scikit-learn {self.model.__class__} model"
2425
return desc
2526

@@ -31,11 +32,10 @@ def construct_meta(
3132
):
3233
"""Create metadata for sklearn model"""
3334
required_pkgs = required_pkgs + ["scikit-learn"]
34-
meta = vetiver_meta(user, version, url, required_pkgs)
35+
meta = _model_meta(user, version, url, required_pkgs)
3536

3637
return meta
3738

38-
3939
def handler_predict(self, input_data, check_ptype):
4040
"""Generates method for /predict endpoint in VetiverAPI
4141
@@ -58,7 +58,7 @@ def handler_predict(self, input_data, check_ptype):
5858
if isinstance(input_data, pd.DataFrame):
5959
prediction = self.model.predict(input_data)
6060
else:
61-
prediction = self.model.predict([input_data])
61+
prediction = self.model.predict([input_data])
6262

6363
# do not check ptype
6464
else:

vetiver/handlers/torch.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22

3+
from ..meta import _model_meta
34
from .base import VetiverHandler
4-
from ..meta import vetiver_meta
55

66
torch_exists = True
77
try:
@@ -18,13 +18,14 @@ class TorchHandler(VetiverHandler):
1818
model : nn.Module
1919
a trained torch model
2020
"""
21+
2122
base_class = torch.nn.Module
23+
2224
def __init__(self, model, ptype_data):
2325
super().__init__(model, ptype_data)
2426

2527
def describe(self):
26-
"""Create description for torch model
27-
"""
28+
"""Create description for torch model"""
2829
desc = f"Pytorch model of type {type(self.model)}"
2930
return desc
3031

@@ -34,10 +35,9 @@ def create_meta(
3435
url: str = None,
3536
required_pkgs: list = [],
3637
):
37-
"""Create metadata for torch model
38-
"""
38+
"""Create metadata for torch model"""
3939
required_pkgs = required_pkgs + ["torch"]
40-
meta = vetiver_meta(user, version, url, required_pkgs)
40+
meta = _model_meta(user, version, url, required_pkgs)
4141

4242
return meta
4343

@@ -62,9 +62,9 @@ def handler_predict(self, input_data, check_ptype):
6262
if check_ptype == True:
6363
input_data = np.array(input_data, dtype=np.array(self.ptype_data).dtype)
6464
prediction = self.model(torch.from_numpy(input_data))
65-
65+
6666
# do not check ptype
67-
else:
67+
else:
6868
input_data = torch.tensor(input_data)
6969
prediction = self.model(input_data)
7070

vetiver/meta.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
def vetiver_meta(user: dict = None, version: str = None,
1+
def _model_meta(user: dict = None, version: str = None,
22
url: str = None, required_pkgs: list = None):
33
"""Populate relevant metadata for VetiverModel
44

vetiver/pin_read_write.py

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
1-
import pydantic
21
import pins
32
import warnings
43
import json
54

65
from .vetiver_model import VetiverModel
7-
from .meta import vetiver_meta
6+
from .meta import _model_meta
87
from .write_fastapi import _choose_version
98

10-
def vetiver_pin_write(board, model: VetiverModel, versioned: bool=True):
9+
10+
def vetiver_pin_write(board, model: VetiverModel, versioned: bool = True):
1111
"""
1212
Pin a trained VetiverModel along with other model metadata.
13-
13+
1414
Parameters
1515
----------
1616
board:
@@ -21,16 +21,18 @@ def vetiver_pin_write(board, model: VetiverModel, versioned: bool=True):
2121
Whether or not the pin should be versioned
2222
"""
2323
if not board.allow_pickle_read:
24-
raise NotImplementedError # must be pickle-able
24+
raise NotImplementedError # must be pickle-able
2525

2626
board.pin_write(
2727
model.model,
28-
name = model.model_name,
29-
type = "joblib",
30-
description = model.description,
31-
metadata = {"required_pkgs": model.metadata.get("required_pkgs"),
32-
"ptype": None if model.ptype == None else model.ptype().json()},
33-
versioned=versioned
28+
name=model.model_name,
29+
type="joblib",
30+
description=model.description,
31+
metadata={
32+
"required_pkgs": model.metadata.get("required_pkgs"),
33+
"ptype": None if model.ptype == None else model.ptype().json(),
34+
},
35+
versioned=versioned,
3436
)
3537

3638
# to do: Model card
@@ -46,15 +48,15 @@ def vetiver_pin_write(board, model: VetiverModel, versioned: bool=True):
4648
def vetiver_pin_read(board, name: str, version: str = None) -> VetiverModel:
4749
"""
4850
Read pin and populate VetiverModel
49-
51+
5052
Parameters
5153
----------
5254
board:
5355
A pin board, created by `pins.board_folder()` or another `board_` function.
5456
name: string
5557
Pin name
5658
version: str
57-
Retrieve a specific version of a pin.
59+
Retrieve a specific version of a pin.
5860
5961
Returns
6062
--------
@@ -63,23 +65,22 @@ def vetiver_pin_read(board, name: str, version: str = None) -> VetiverModel:
6365
Notes
6466
-----
6567
If reading a board from RSConnect, the `board` argument must be in "username/modelname" format.
66-
68+
6769
"""
68-
version = version if version is not None else _choose_version(board.pin_versions(name))
69-
70-
model = board.pin_read(name, version)
71-
meta = board.pin_meta(name)
72-
73-
v = VetiverModel(model = model,
74-
model_name = name,
75-
description = meta.description,
76-
metadata = vetiver_meta(user = meta.user,
77-
version = version,
78-
url = meta.user.get("url"), # None all the time, besides Connect
79-
required_pkgs = meta.user.get("required_pkgs")
80-
),
81-
ptype_data = json.loads(meta.user.get("ptype")) if meta.user.get("ptype") else None,
82-
versioned = True
83-
)
84-
70+
71+
warnings.warn(
72+
"vetiver_pin_read will be removed in v1.0.0. Use classmethod VetiverModel.from_pin() instead",
73+
DeprecationWarning,
74+
)
75+
76+
version = (
77+
version if version is not None else _choose_version(board.pin_versions(name))
78+
)
79+
80+
v = VetiverModel.from_pin(
81+
board = board,
82+
name = name,
83+
version = version
84+
)
85+
8586
return v

vetiver/tests/test_build_vetiver_model.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import pytest
2+
import sklearn
23

3-
from vetiver.vetiver_model import VetiverModel
4+
import vetiver as vt
45
from vetiver.mock import get_mock_data, get_mock_model
56

67
import pandas as pd
7-
from numpy import int64
8+
import pydantic
9+
import pins
810

911
# Load data, model
1012
X_df, y = get_mock_data()
@@ -14,7 +16,7 @@
1416

1517
def test_vetiver_model_array_ptype():
1618
# build VetiverModel, no ptype
17-
vt1 = VetiverModel(
19+
vt1 = vt.VetiverModel(
1820
model=model,
1921
ptype_data=X_array,
2022
model_name="iris",
@@ -24,12 +26,13 @@ def test_vetiver_model_array_ptype():
2426
)
2527

2628
assert vt1.model == model
29+
assert isinstance(vt1.ptype.construct(), pydantic.BaseModel)
2730
assert list(vt1.ptype.__fields__.values())[0].type_ == int
2831

2932

3033
def test_vetiver_model_df_ptype():
3134
# build VetiverModel, df ptype_data
32-
vt2 = VetiverModel(
35+
vt2 = vt.VetiverModel(
3336
model=model,
3437
ptype_data=X_df,
3538
model_name="iris",
@@ -39,12 +42,13 @@ def test_vetiver_model_df_ptype():
3942
)
4043

4144
assert vt2.model == model
45+
assert isinstance(vt2.ptype.construct(), pydantic.BaseModel)
4246
assert list(vt2.ptype.__fields__.values())[0].type_ == int
4347

4448

4549
def test_vetiver_model_dict_ptype():
4650
dict_data = {"B": 0, "C": 0, "D": 0}
47-
vt3 = VetiverModel(
51+
vt3 = vt.VetiverModel(
4852
model=model,
4953
ptype_data=dict_data,
5054
model_name="iris",
@@ -54,12 +58,13 @@ def test_vetiver_model_dict_ptype():
5458
)
5559

5660
assert vt3.model == model
61+
assert isinstance(vt3.ptype.construct(), pydantic.BaseModel)
5762
assert list(vt3.ptype.__fields__.values())[0].type_ == int
5863

5964

6065
def test_vetiver_model_no_ptype():
6166
# build VetiverModel, no ptype
62-
vt4 = VetiverModel(
67+
vt4 = vt.VetiverModel(
6368
model=model,
6469
ptype_data=None,
6570
model_name="iris",
@@ -70,3 +75,22 @@ def test_vetiver_model_no_ptype():
7075

7176
assert vt4.model == model
7277
assert vt4.ptype == None
78+
79+
80+
def test_vetiver_model_from_pin():
81+
82+
v = vt.VetiverModel(
83+
model=model,
84+
ptype_data=X_df,
85+
model_name="model",
86+
versioned=None,
87+
description=None,
88+
metadata=None,
89+
)
90+
board = pins.board_temp(allow_pickle_read=True)
91+
vt.vetiver_pin_write(board=board, model=v)
92+
v2 = vt.VetiverModel.from_pin(board, "model")
93+
assert isinstance(v2, vt.VetiverModel)
94+
assert isinstance(v2.model, sklearn.base.BaseEstimator)
95+
assert isinstance(v2.ptype.construct(), pydantic.BaseModel)
96+
board.pin_delete("model")

vetiver/tests/test_pin_read.py

Lines changed: 0 additions & 50 deletions
This file was deleted.

0 commit comments

Comments
 (0)