Skip to content

Commit 9e3b580

Browse files
committed
updates from review
1 parent aeef641 commit 9e3b580

5 files changed

Lines changed: 10 additions & 5 deletions

File tree

vetiver/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,6 @@
1414
from .meta import *
1515
from .write_docker import *
1616
from .write_fastapi import *
17-
from .handlers._interface import *
17+
from .handlers._interface import create_handler
1818
from .handlers.sklearn import SKLearnHandler
1919
from .handlers.torch import TorchHandler

vetiver/handlers/_interface.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from typing import Any
2-
from vetiver.handlers import torch, sklearn
1+
from vetiver.handlers import torch, sklearn, base
32
from functools import singledispatch
43

54
class InvalidModelError(Exception):
@@ -20,7 +19,6 @@ def __init__(
2019
type {_model_type}. If your model is not one of \
2120
(scikit-learn, torch), you should create and register \
2221
the handler. Here is a template for such a function: \
23-
from pydantic import create_model
2422
from vetiver.handlers._interface import create_handler
2523
from vetiver.handlers.base import VetiverHandler
2624
@@ -86,3 +84,5 @@ def create_handler(model, ptype_data):
8684
create_handler.register(sklearn.SKLearnHandler.base_class, sklearn.SKLearnHandler)
8785

8886
create_handler.register(torch.TorchHandler.base_class, torch.TorchHandler)
87+
88+
#create_handler.register(base.VetiverHandler, lambda model: model)

vetiver/handlers/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ def __init__(self, model, ptype_data):
1616
self.model = model
1717
self.ptype_data = ptype_data
1818

19+
def __class__(self):
20+
...
21+
1922
def create_description(self):
2023
"""Create description for model"""
2124
desc = f"{self.model.__class__} model"
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ class SKLearnHandler(VetiverHandler):
1313
model : sklearn.base.BaseEstimator
1414
a trained sklearn model
1515
"""
16-
16+
base_class = sklearn.base.BaseEstimator
17+
1718
def __init__(self, model, ptype_data):
1819
super().__init__(model, ptype_data)
1920

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class TorchHandler(VetiverHandler):
1818
model : nn.Module
1919
a trained torch model
2020
"""
21+
base_class = torch.nn.Module
2122
def __init__(self, model, ptype_data):
2223
super().__init__(model, ptype_data)
2324

0 commit comments

Comments
 (0)