|
2 | 2 | from ..ptype import _vetiver_create_ptype |
3 | 3 | import numpy as np |
4 | 4 |
|
| 5 | +torch_exists = True |
| 6 | +try: |
| 7 | + import torch |
| 8 | +except ImportError: |
| 9 | + torch_exists = False |
| 10 | + |
| 11 | + |
5 | 12 | class TorchHandler: |
6 | 13 | """Handler class for creating VetiverModels with torch. |
7 | 14 |
|
@@ -77,21 +84,22 @@ def handler_predict(self, input_data, check_ptype): |
77 | 84 | prediction |
78 | 85 | Prediction from model |
79 | 86 | """ |
80 | | - import torch |
81 | | - |
82 | | - if check_ptype == True: |
83 | | - input_data = np.array(input_data, dtype=np.array(self.ptype_data).dtype) |
84 | | - prediction = self.model(torch.from_numpy(input_data)) |
85 | | - |
86 | | - # do not check ptype |
| 87 | + if torch_exists: |
| 88 | + if check_ptype == True: |
| 89 | + input_data = np.array(input_data, dtype=np.array(self.ptype_data).dtype) |
| 90 | + prediction = self.model(torch.from_numpy(input_data)) |
| 91 | + |
| 92 | + # do not check ptype |
| 93 | + else: |
| 94 | + batch = True |
| 95 | + if not isinstance(input_data, list): |
| 96 | + batch = False |
| 97 | + input_data = input_data.split(",") # user delimiter ? |
| 98 | + input_data = np.array(input_data, dtype=np.array(self.ptype_data).dtype) |
| 99 | + if not batch: |
| 100 | + input_data = input_data.reshape(1, -1) |
| 101 | + prediction = self.model(torch.from_numpy(input_data)) |
87 | 102 | else: |
88 | | - batch = True |
89 | | - if not isinstance(input_data, list): |
90 | | - batch = False |
91 | | - input_data = input_data.split(",") # user delimiter ? |
92 | | - input_data = np.array(input_data, dtype=np.array(self.ptype_data).dtype) |
93 | | - if not batch: |
94 | | - input_data = input_data.reshape(1, -1) |
95 | | - prediction = self.model(torch.from_numpy(input_data)) |
| 103 | + raise ImportError("Cannot import `torch`.") |
96 | 104 |
|
97 | 105 | return prediction |
0 commit comments