Skip to content

Commit 17c3fd4

Browse files
authored
Merge pull request #126 from rstudio/metadata
Refactor metadata
2 parents 612d957 + 2c83fa4 commit 17c3fd4

17 files changed

Lines changed: 140 additions & 140 deletions

.github/workflows/tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ jobs:
7878
python-version: 3.8
7979
- name: Install dependencies
8080
run: |
81+
python -m pip install --upgrade pip
8182
python -m pip install ".[dev]"
8283
python -m pip install --upgrade git+https://github.com/rstudio/vetiver-python@${{ github.sha }}
8384
- name: run Docker

docs/source/advancedusage/custom_handler.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ class CustomHandler(BaseHandler):
1212
super().__init__(model, ptype_data)
1313

1414
model_type = staticmethod(lambda: newmodeltype)
15+
pip_name = "scikit-learn" # pkg name on pip, used for tracking pkg versions
1516

1617
def handler_predict(self, input_data, check_ptype: bool):
1718
"""

vetiver/attach_pkgs.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import tempfile
22
import os
3-
from vetiver import VetiverModel
3+
from .vetiver_model import VetiverModel
4+
from .meta import VetiverMeta
45

56

67
def load_pkgs(model: VetiverModel = None, packages: list = None, path=""):
@@ -19,8 +20,12 @@ def load_pkgs(model: VetiverModel = None, packages: list = None, path=""):
1920
required_pkgs = ["vetiver"]
2021
if packages:
2122
required_pkgs = list(set(required_pkgs + packages))
22-
if model.metadata.get("required_pkgs"):
23-
required_pkgs = list(set(required_pkgs + model.metadata.get("required_pkgs")))
23+
24+
if isinstance(model.metadata, dict):
25+
model.metadata = VetiverMeta.from_dict(model.metadata)
26+
27+
if model.metadata.required_pkgs:
28+
required_pkgs = list(set(required_pkgs + model.metadata.required_pkgs))
2429

2530
tmp = tempfile.NamedTemporaryFile(suffix=".in", delete=False)
2631
tmp.close()

vetiver/handlers/base.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from contextlib import suppress
44

55
from ..prototype import vetiver_create_prototype
6-
from ..meta import _model_meta
6+
from ..meta import VetiverMeta
77

88

99
class InvalidModelError(Exception):
@@ -43,7 +43,7 @@ def create_handler(model, prototype_data):
4343
>>> model = vetiver.mock.get_mock_model()
4444
>>> handler = vetiver.create_handler(model, X)
4545
>>> handler.describe()
46-
"Scikit-learn <class 'sklearn.dummy.DummyRegressor'> model"
46+
'A scikit-learn DummyRegressor model'
4747
"""
4848

4949
raise InvalidModelError(
@@ -79,19 +79,20 @@ def __init__(self, model, prototype_data):
7979

8080
def describe(self):
8181
"""Create description for model"""
82-
desc = f"{self.model.__class__} model"
82+
83+
pip_name = self.pip_name if hasattr(self, "pip_name") else ""
84+
obj_name = type(self.model).__qualname__
85+
86+
desc = f"A {pip_name} {obj_name} model"
87+
8388
return desc
8489

85-
def create_meta(
86-
user: list = None,
87-
version: str = None,
88-
url: str = None,
89-
required_pkgs: list = [],
90-
):
90+
def create_meta(self, metadata):
9191
"""Create metadata for a model"""
92-
meta = _model_meta(user, version, url, required_pkgs)
9392

94-
return meta
93+
pip_name = self.pip_name if hasattr(self, "pip_name") else None
94+
95+
return VetiverMeta.from_dict(metadata, pip_name)
9596

9697
def construct_prototype(self):
9798
"""Create data prototype for a model

vetiver/handlers/sklearn.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import pandas as pd
22
import sklearn
33

4-
from ..meta import _model_meta
54
from .base import BaseHandler
65

76

@@ -15,23 +14,7 @@ class SKLearnHandler(BaseHandler):
1514
"""
1615

1716
model_class = staticmethod(lambda: sklearn.base.BaseEstimator)
18-
19-
def describe(self):
20-
"""Create description for sklearn model"""
21-
desc = f"Scikit-learn {self.model.__class__} model"
22-
return desc
23-
24-
def create_meta(
25-
user: list = None,
26-
version: str = None,
27-
url: str = None,
28-
required_pkgs: list = [],
29-
):
30-
"""Create metadata for sklearn model"""
31-
required_pkgs = required_pkgs + ["scikit-learn"]
32-
meta = _model_meta(user, version, url, required_pkgs)
33-
34-
return meta
17+
pip_name = "scikit-learn"
3518

3619
def handler_predict(self, input_data, check_prototype):
3720
"""Generates method for /predict endpoint in VetiverAPI

vetiver/handlers/statsmodels.py

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import pandas as pd
22

3-
from ..meta import _model_meta
43
from .base import BaseHandler
54

65
sm_exists = True
@@ -20,23 +19,8 @@ class StatsmodelsHandler(BaseHandler):
2019
"""
2120

2221
model_class = staticmethod(lambda: statsmodels.base.wrapper.ResultsWrapper)
23-
24-
def describe(self):
25-
"""Create description for statsmodels model"""
26-
desc = f"Statsmodels {self.model.__class__} model."
27-
return desc
28-
29-
def create_meta(
30-
user: list = None,
31-
version: str = None,
32-
url: str = None,
33-
required_pkgs: list = [],
34-
):
35-
"""Create metadata for statsmodel"""
36-
required_pkgs = required_pkgs + ["statsmodels"]
37-
meta = _model_meta(user, version, url, required_pkgs)
38-
39-
return meta
22+
if sm_exists:
23+
pip_name = "statsmodels"
4024

4125
def handler_predict(self, input_data, check_prototype):
4226
"""Generates method for /predict endpoint in VetiverAPI

vetiver/handlers/torch.py

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import numpy as np
22

3-
from ..meta import _model_meta
43
from .base import BaseHandler
54

65
torch_exists = True
@@ -20,23 +19,8 @@ class TorchHandler(BaseHandler):
2019
"""
2120

2221
model_class = staticmethod(lambda: torch.nn.Module)
23-
24-
def describe(self):
25-
"""Create description for torch model"""
26-
desc = f"Pytorch model of type {type(self.model)}"
27-
return desc
28-
29-
def create_meta(
30-
user: list = None,
31-
version: str = None,
32-
url: str = None,
33-
required_pkgs: list = [],
34-
):
35-
"""Create metadata for torch model"""
36-
required_pkgs = required_pkgs + ["torch"]
37-
meta = _model_meta(user, version, url, required_pkgs)
38-
39-
return meta
22+
if torch_exists:
23+
pip_name = "torch"
4024

4125
def handler_predict(self, input_data, check_prototype):
4226
"""Generates method for /predict endpoint in VetiverAPI

vetiver/handlers/xgboost.py

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import pandas as pd
22

3-
from ..meta import _model_meta
43
from .base import BaseHandler
54

65
xgb_exists = True
@@ -20,23 +19,8 @@ class XGBoostHandler(BaseHandler):
2019
"""
2120

2221
model_class = staticmethod(lambda: xgboost.Booster)
23-
24-
def describe(self):
25-
"""Create description for xgboost model"""
26-
desc = f"XGBoost {self.model.__class__} model."
27-
return desc
28-
29-
def create_meta(
30-
user: list = None,
31-
version: str = None,
32-
url: str = None,
33-
required_pkgs: list = [],
34-
):
35-
"""Create metadata for xgboost"""
36-
required_pkgs = required_pkgs + ["xgboost"]
37-
meta = _model_meta(user, version, url, required_pkgs)
38-
39-
return meta
22+
if xgb_exists:
23+
pip_name = "xgboost"
4024

4125
def handler_predict(self, input_data, check_prototype):
4226
"""Generates method for /predict endpoint in VetiverAPI

vetiver/meta.py

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,33 @@
1-
def _model_meta(
2-
user: dict = None, version: str = None, url: str = None, required_pkgs: list = None
3-
):
4-
"""Populate relevant metadata for VetiverModel
5-
6-
Args
7-
----
8-
user: dict
9-
Extra user-defined information
10-
version: str
11-
Model version, generally populated from pins
12-
url: str
13-
Discoverable URL for API
14-
required_pkgs: list
15-
Packages necessary to make predictions
16-
"""
17-
meta = {
18-
"user": user,
19-
"version": version,
20-
"url": url,
21-
"required_pkgs": required_pkgs,
22-
}
23-
return meta
1+
from dataclasses import dataclass, asdict, field
2+
from typing import Mapping
3+
4+
5+
@dataclass
6+
class VetiverMeta:
7+
"""Metadata in a VetiverModel"""
8+
9+
user: "dict | None" = field(default_factory=dict)
10+
version: "str | None" = None
11+
url: "str | None" = None
12+
required_pkgs: "list | None" = field(default_factory=list)
13+
14+
def to_dict(self) -> Mapping:
15+
data = asdict(self)
16+
17+
return data
18+
19+
@classmethod
20+
def from_dict(cls, metadata, pip_name=None) -> "VetiverMeta":
21+
22+
metadata = {} if metadata is None else metadata
23+
24+
user = metadata.get("user", metadata)
25+
version = metadata.get("version", None)
26+
url = metadata.get("url", None)
27+
required_pkgs = metadata.get("required_pkgs", [])
28+
29+
if pip_name:
30+
if not list(filter(lambda x: pip_name in x, required_pkgs)):
31+
required_pkgs = required_pkgs + [f"{pip_name}"]
32+
33+
return cls(user, version, url, required_pkgs)

vetiver/pin_read_write.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .vetiver_model import VetiverModel
2+
from .meta import VetiverMeta
23
from .utils import inform
34
import warnings
45
import logging
@@ -54,15 +55,22 @@ def vetiver_pin_write(board, model: VetiverModel, versioned: bool = True):
5455
# convert older model's ptype to prototype
5556
if hasattr(model, "ptype"):
5657
model.prototype = model.ptype
58+
delattr(model, "ptype")
59+
# metadata is dict
60+
if isinstance(model.metadata, dict):
61+
model.metadata = VetiverMeta.from_dict(model.metadata)
5762

5863
board.pin_write(
5964
model.model,
6065
name=model.model_name,
6166
type="joblib",
6267
description=model.description,
6368
metadata={
64-
"required_pkgs": model.metadata.get("required_pkgs"),
65-
"prototype": None if model.prototype is None else model.prototype().json(),
69+
"user": model.metadata.user,
70+
"vetiver_meta": {
71+
"required_pkgs": model.metadata.required_pkgs,
72+
"prototype": None if not model.prototype else model.prototype().json(),
73+
},
6674
},
6775
versioned=versioned,
6876
)

0 commit comments

Comments
 (0)