1+ from cgitb import reset
12import pytest
23
34from vetiver .vetiver_model import VetiverModel
45from vetiver import VetiverAPI
56from fastapi .testclient import TestClient
67
8+ import torch
79import torch .nn as nn
810import numpy as np
911
10- np .random .seed (500 )
11-
1212
1313def _build_torch_v ():
1414
@@ -59,7 +59,7 @@ def test_vetiver_build():
5959
6060
6161def test_torch_predict_ptype ():
62-
62+ torch . manual_seed ( 3 )
6363 x_train , torch_model = _build_torch_v ()
6464 v = VetiverModel (torch_model , save_ptype = True , ptype_data = x_train )
6565 v_api = VetiverAPI (v )
@@ -69,10 +69,11 @@ def test_torch_predict_ptype():
6969 response = client .post ("/predict/" , json = data )
7070
7171 assert response .status_code == 200 , response .text
72+ assert response .json () == {"prediction" :[- 4.060722351074219 ]}, response .text
7273
7374
7475def test_torch_predict_ptype_batch ():
75-
76+ torch . manual_seed ( 3 )
7677 x_train , torch_model = _build_torch_v ()
7778 v = VetiverModel (torch_model , save_ptype = True , ptype_data = x_train )
7879 v_api = VetiverAPI (v )
@@ -82,6 +83,7 @@ def test_torch_predict_ptype_batch():
8283 response = client .post ("/predict/" , json = data )
8384
8485 assert response .status_code == 200 , response .text
86+ assert response .json () == {"prediction" :[[- 4.060722351074219 ],[- 4.060722351074219 ]]}, response .text
8587
8688
8789def test_torch_predict_ptype_error ():
@@ -98,7 +100,7 @@ def test_torch_predict_ptype_error():
98100
99101
100102def test_torch_predict_no_ptype ():
101-
103+ torch . manual_seed ( 3 )
102104 x_train , torch_model = _build_torch_v ()
103105 v = VetiverModel (torch_model , save_ptype = False , ptype_data = x_train )
104106 v_api = VetiverAPI (v , check_ptype = False )
@@ -107,6 +109,20 @@ def test_torch_predict_no_ptype():
107109 data = "3.3"
108110 response = client .post ("/predict/" , json = data )
109111 assert response .status_code == 200 , response .text
112+ assert response .json () == {"prediction" :[[- 4.060722351074219 ]]}, response .text
113+
114+
115+ def test_torch_predict_no_ptype_batch ():
116+ torch .manual_seed (3 )
117+ x_train , torch_model = _build_torch_v ()
118+ v = VetiverModel (torch_model , save_ptype = False , ptype_data = x_train )
119+ v_api = VetiverAPI (v , check_ptype = False )
120+
121+ client = TestClient (v_api .app )
122+ data = [["3.3" ], ["3.3" ]]
123+ response = client .post ("/predict/" , json = data )
124+ assert response .status_code == 200 , response .text
125+ assert response .json () == {"prediction" :[[- 4.060722351074219 ],[- 4.060722351074219 ]]}, response .text
110126
111127
112128def test_torch_predict_no_ptype_error ():
0 commit comments