Skip to content

Commit eee87b5

Browse files
authored
Merge pull request #29 from isabelizimm/dev-optional-deps
make torch optional dependency
2 parents 96b06c5 + d05db52 commit eee87b5

4 files changed

Lines changed: 38 additions & 23 deletions

File tree

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ jobs:
2323
- name: Install dependencies
2424
run: |
2525
python -m pip install --upgrade pip
26-
python -m pip install -e .[dev]
26+
python -m pip install -e .[dev,torch]
2727
- name: Run Tests
2828
run: |
2929
pytest --cov --cov-report xml .

setup.cfg

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ install_requires =
3131
uvicorn
3232
scikit-learn
3333
nest-asyncio
34-
torch
3534
requests
3635
pins @ git+https://github.com/machow/pins-python@f9391b5937b3a13b2a2c841855ca94ab281b2cef
3736

@@ -43,3 +42,6 @@ dev =
4342
sphinx-autodoc-typehints
4443
sphinx-book-theme
4544
myst-parser
45+
46+
torch =
47+
torch

vetiver/handlers/_interface.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
from vetiver.handlers import pytorch_vt, sklearn_vt
2-
from torch import nn
3-
import sklearn
2+
import sklearn
3+
4+
torch_exists = True
5+
try:
6+
import torch
7+
except ImportError:
8+
torch_exists = False
49

510
def create_translator(model, ptype_data, save_ptype):
611
"""check for model type to handle prediction
@@ -15,11 +20,11 @@ def create_translator(model, ptype_data, save_ptype):
1520
pytorch_vt.TorchHandler or sklearn_vt.SKLearnHandler
1621
Handler class for specified model type
1722
"""
23+
if torch_exists:
24+
if isinstance(model, torch.nn.Module):
25+
return pytorch_vt.TorchHandler(model, ptype_data, save_ptype)
1826

19-
if isinstance(model, nn.Module):
20-
return pytorch_vt.TorchHandler(model, ptype_data, save_ptype)
21-
22-
elif isinstance(model, sklearn.base.BaseEstimator):
27+
if isinstance(model, sklearn.base.BaseEstimator):
2328
return sklearn_vt.SKLearnHandler(model, ptype_data, save_ptype)
2429

2530
else:

vetiver/handlers/pytorch_vt.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22
from ..ptype import _vetiver_create_ptype
33
import numpy as np
44

5+
torch_exists = True
6+
try:
7+
import torch
8+
except ImportError:
9+
torch_exists = False
10+
11+
512
class TorchHandler:
613
"""Handler class for creating VetiverModels with torch.
714
@@ -77,21 +84,22 @@ def handler_predict(self, input_data, check_ptype):
7784
prediction
7885
Prediction from model
7986
"""
80-
import torch
81-
82-
if check_ptype == True:
83-
input_data = np.array(input_data, dtype=np.array(self.ptype_data).dtype)
84-
prediction = self.model(torch.from_numpy(input_data))
85-
86-
# do not check ptype
87+
if torch_exists:
88+
if check_ptype == True:
89+
input_data = np.array(input_data, dtype=np.array(self.ptype_data).dtype)
90+
prediction = self.model(torch.from_numpy(input_data))
91+
92+
# do not check ptype
93+
else:
94+
batch = True
95+
if not isinstance(input_data, list):
96+
batch = False
97+
input_data = input_data.split(",") # user delimiter ?
98+
input_data = np.array(input_data, dtype=np.array(self.ptype_data).dtype)
99+
if not batch:
100+
input_data = input_data.reshape(1, -1)
101+
prediction = self.model(torch.from_numpy(input_data))
87102
else:
88-
batch = True
89-
if not isinstance(input_data, list):
90-
batch = False
91-
input_data = input_data.split(",") # user delimiter ?
92-
input_data = np.array(input_data, dtype=np.array(self.ptype_data).dtype)
93-
if not batch:
94-
input_data = input_data.reshape(1, -1)
95-
prediction = self.model(torch.from_numpy(input_data))
103+
raise ImportError("Cannot import `torch`.")
96104

97105
return prediction

0 commit comments

Comments
 (0)