Skip to content

Commit d660102

Browse files
committed
use staticmethod for new types
1 parent b7cdfeb commit d660102

5 files changed

Lines changed: 33 additions & 15 deletions

File tree

.github/workflows/tests.yml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,21 @@ jobs:
6464
- name: Run tests
6565
run: |
6666
pytest vetiver -m 'rsc_test'
67+
68+
test-no-torch:
69+
name: "Test no-torch"
70+
runs-on: ubuntu-latest
71+
if: ${{ !github.event.pull_request.head.repo.fork }}
72+
steps:
73+
- uses: actions/checkout@v2
74+
- uses: actions/setup-python@v2
75+
with:
76+
python-version: 3.8
77+
- name: Install dependencies
78+
run: |
79+
python -m pip install --upgrade pip
80+
python -m pip install -e .[dev]
81+
82+
- name: Run tests
83+
run: |
84+
pytest vetiver/tests/test_sklearn.py

vetiver/handlers/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,11 @@ class VetiverHandler:
6868
def __init_subclass__(cls, **kwargs):
6969
super().__init_subclass__(**kwargs)
7070
with suppress(AttributeError, NameError):
71-
create_handler.register(cls.base_class, cls)
71+
create_handler.register(cls.base_class(), cls)
7272

73-
def __new__(cls, value=None):
74-
implementation_cls = create_handler.registry[type(value)]
75-
return super().__new__(implementation_cls)
73+
# def __new__(cls, value=None):
74+
# implementation_cls = create_handler.registry[type(value)]
75+
# return super().__new__(implementation_cls)
7676

7777
def __init__(self, model, ptype_data):
7878
self.model = model

vetiver/handlers/sklearn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class SKLearnHandler(VetiverHandler):
1414
a trained sklearn model
1515
"""
1616

17-
base_class = sklearn.base.BaseEstimator
17+
base_class = staticmethod(lambda: sklearn.base.BaseEstimator)
1818

1919
def __init__(self, model, ptype_data):
2020
super().__init__(model, ptype_data)
@@ -54,7 +54,7 @@ def handler_predict(self, input_data, check_ptype):
5454
Prediction from model
5555
"""
5656

57-
if check_ptype == True:
57+
if check_ptype:
5858
if isinstance(input_data, pd.DataFrame):
5959
prediction = self.model.predict(input_data)
6060
else:

vetiver/handlers/torch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class TorchHandler(VetiverHandler):
1919
a trained torch model
2020
"""
2121

22-
base_class = torch.nn.Module
22+
base_class = staticmethod(lambda: torch.nn.Module)
2323

2424
def __init__(self, model, ptype_data):
2525
super().__init__(model, ptype_data)
@@ -59,7 +59,7 @@ def handler_predict(self, input_data, check_ptype):
5959
Prediction from model
6060
"""
6161
if torch_exists:
62-
if check_ptype == True:
62+
if check_ptype:
6363
input_data = np.array(input_data, dtype=np.array(self.ptype_data).dtype)
6464
prediction = self.model(torch.from_numpy(input_data))
6565

vetiver/tests/test_pytorch.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
from vetiver.vetiver_model import VetiverModel
22
from vetiver import VetiverAPI
33
from fastapi.testclient import TestClient
4-
5-
import torch
6-
import torch.nn as nn
4+
import pytest
75
import numpy as np
86

7+
torch = pytest.importorskip("torch")
8+
9+
import torch # noqa
10+
911

1012
def _build_torch_v():
1113

12-
# Hyper-parameters
1314
input_size = 1
1415
output_size = 1
1516

16-
# Toy dataset
1717
x_train = np.array(
1818
[
1919
[3.3],
@@ -34,8 +34,8 @@ def _build_torch_v():
3434
],
3535
dtype=np.float32,
3636
)
37-
# Linear regression model
38-
torch_model = nn.Linear(input_size, output_size)
37+
38+
torch_model = torch.nn.Linear(input_size, output_size)
3939
return x_train, torch_model
4040

4141

0 commit comments

Comments
 (0)