1+ from cgitb import reset
12import pytest
23
34from vetiver .vetiver_model import VetiverModel
45from vetiver import VetiverAPI
56from fastapi .testclient import TestClient
67
8+ import torch
79import torch .nn as nn
810import numpy as np
911
10- np .random .seed (500 )
11-
1212
1313def _build_torch_v ():
1414
@@ -59,7 +59,7 @@ def test_vetiver_build():
5959
6060
6161def test_torch_predict_ptype ():
62-
62+ torch . manual_seed ( 3 )
6363 x_train , torch_model = _build_torch_v ()
6464 v = VetiverModel (torch_model , save_ptype = True , ptype_data = x_train )
6565 v_api = VetiverAPI (v )
@@ -69,6 +69,22 @@ def test_torch_predict_ptype():
6969 response = client .post ("/predict/" , json = data )
7070
7171 assert response .status_code == 200 , response .text
72+ assert response .json () == {"prediction" :[- 4.060722351074219 ]}, response .text
73+
74+
75+ def test_torch_predict_ptype_batch ():
76+ torch .manual_seed (3 )
77+ x_train , torch_model = _build_torch_v ()
78+ v = VetiverModel (torch_model , save_ptype = True , ptype_data = x_train )
79+ v_api = VetiverAPI (v )
80+
81+ client = TestClient (v_api .app )
82+ data = [{"0" : 3.3 }, {"0" : 3.3 }]
83+ response = client .post ("/predict/" , json = data )
84+
85+ assert response .status_code == 200 , response .text
86+ assert response .json () == {"prediction" :[[- 4.060722351074219 ],[- 4.060722351074219 ]]}, response .text
87+
7288
7389def test_torch_predict_ptype_error ():
7490
@@ -80,19 +96,33 @@ def test_torch_predict_ptype_error():
8096 data = {"0" : "bad" }
8197 response = client .post ("/predict/" , json = data )
8298
83- assert response .status_code == 422 , response .text # value is not a valid float
99+ assert response .status_code == 422 , response .text # value is not a valid float
84100
85101
86102def test_torch_predict_no_ptype ():
103+ torch .manual_seed (3 )
104+ x_train , torch_model = _build_torch_v ()
105+ v = VetiverModel (torch_model , save_ptype = False , ptype_data = x_train )
106+ v_api = VetiverAPI (v , check_ptype = False )
107+
108+ client = TestClient (v_api .app )
109+ data = "3.3"
110+ response = client .post ("/predict/" , json = data )
111+ assert response .status_code == 200 , response .text
112+ assert response .json () == {"prediction" :[[- 4.060722351074219 ]]}, response .text
113+
87114
115+ def test_torch_predict_no_ptype_batch ():
116+ torch .manual_seed (3 )
88117 x_train , torch_model = _build_torch_v ()
89118 v = VetiverModel (torch_model , save_ptype = False , ptype_data = x_train )
90119 v_api = VetiverAPI (v , check_ptype = False )
91120
92121 client = TestClient (v_api .app )
93- data = ' 3.3'
122+ data = [[ " 3.3" ], [ "3.3" ]]
94123 response = client .post ("/predict/" , json = data )
95124 assert response .status_code == 200 , response .text
125+ assert response .json () == {"prediction" :[[- 4.060722351074219 ],[- 4.060722351074219 ]]}, response .text
96126
97127
98128def test_torch_predict_no_ptype_error ():
@@ -102,6 +132,6 @@ def test_torch_predict_no_ptype_error():
102132 v_api = VetiverAPI (v , check_ptype = False )
103133
104134 client = TestClient (v_api .app )
105- data = ' bad'
135+ data = " bad"
106136 with pytest .raises (ValueError ):
107137 client .post ("/predict/" , json = data )
0 commit comments