Skip to content

Commit aeef641

Browse files
committed
fixes
1 parent f233463 commit aeef641

2 files changed

Lines changed: 5 additions & 17 deletions

File tree

vetiver/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,5 @@
1515
from .write_docker import *
1616
from .write_fastapi import *
1717
from .handlers._interface import *
18-
from .handlers.scikitlearn import *
19-
from .handlers.pytorch import *
18+
from .handlers.sklearn import SKLearnHandler
19+
from .handlers.torch import TorchHandler

vetiver/handlers/_interface.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,6 @@
11
from typing import Any
2-
from vetiver.handlers import pytorch, scikitlearn
2+
from vetiver.handlers import torch, sklearn
33
from functools import singledispatch
4-
import sklearn
5-
6-
torch_exists = True
7-
try:
8-
import torch
9-
except ImportError:
10-
torch_exists = False
114

125
class InvalidModelError(Exception):
136
"""
@@ -90,11 +83,6 @@ def create_handler(model, ptype_data):
9083
"""
9184
raise InvalidModelError(message=CREATE_PTYPE_TPL.format(_model_type=type(model)))
9285

93-
@create_handler.register
94-
def _(model: sklearn.base.BaseEstimator, ptype_data: Any):
95-
return scikitlearn.SKLearnHandler(model, ptype_data)
86+
create_handler.register(sklearn.SKLearnHandler.base_class, sklearn.SKLearnHandler)
9687

97-
if torch_exists:
98-
@create_handler.register
99-
def _(model: torch.nn.Module, ptype_data: Any):
100-
return pytorch.TorchHandler(model, ptype_data)
88+
create_handler.register(torch.TorchHandler.base_class, torch.TorchHandler)

0 commit comments

Comments
 (0)