1- from vetiver import mock , VetiverModel , VetiverAPI
1+ from vetiver import (
2+ mock ,
3+ VetiverModel ,
4+ VetiverAPI ,
5+ vetiver_create_prototype ,
6+ InvalidPTypeError ,
7+ )
8+ from pydantic import BaseModel , conint
29from fastapi .testclient import TestClient
10+ import numpy as np
311import pytest
412import sys
513
614
715@pytest .fixture
816def vetiver_model ():
17+ np .random .seed (500 )
918 X , y = mock .get_mock_data ()
1019 model = mock .get_mock_model ().fit (X , y )
1120 v = VetiverModel (
@@ -25,6 +34,31 @@ def client(vetiver_model):
2534 return TestClient (app .app )
2635
2736
37+ @pytest .fixture
38+ def complex_prototype_model ():
39+ np .random .seed (500 )
40+
41+ class CustomPrototype (BaseModel ):
42+ B : conint (gt = 42 )
43+ C : conint (gt = 42 )
44+ D : conint (gt = 42 )
45+
46+ X , y = mock .get_mock_data ()
47+ model = mock .get_mock_model ().fit (X , y )
48+ v = VetiverModel (
49+ model = model ,
50+ # move to model_construct for pydantic 3
51+ prototype_data = CustomPrototype .construct (),
52+ model_name = "my_model" ,
53+ versioned = None ,
54+ description = "A regression model for testing purposes" ,
55+ )
56+ # dont actually want to make predictions, just for looking at schema
57+ app = VetiverAPI (v , check_prototype = False )
58+
59+ return TestClient (app .app )
60+
61+
2862def test_get_ping (client ):
2963 response = client .get ("/ping" )
3064 assert response .status_code == 200 , response .text
@@ -46,3 +80,40 @@ def test_get_metadata(client):
4680 "required_pkgs" : ["scikit-learn" ],
4781 "python_version" : list (sys .version_info ), # JSON will return a list
4882 }
83+
84+
85+ def test_get_prototype (client , vetiver_model ):
86+ response = client .get ("/prototype" )
87+ assert response .status_code == 200 , response .text
88+ assert response .json () == {
89+ "properties" : {
90+ "B" : {"default" : 55 , "type" : "integer" },
91+ "C" : {"default" : 65 , "type" : "integer" },
92+ "D" : {"default" : 17 , "type" : "integer" },
93+ },
94+ "title" : "prototype" ,
95+ "type" : "object" ,
96+ }
97+
98+ assert (
99+ vetiver_model .prototype .construct ().dict ()
100+ == vetiver_create_prototype (response .json ()).construct ().dict ()
101+ )
102+
103+
104+ def test_complex_prototype (complex_prototype_model ):
105+ response = complex_prototype_model .get ("/prototype" )
106+ assert response .status_code == 200 , response .text
107+ assert response .json () == {
108+ "properties" : {
109+ "B" : {"exclusiveMinimum" : 42 , "type" : "integer" },
110+ "C" : {"exclusiveMinimum" : 42 , "type" : "integer" },
111+ "D" : {"exclusiveMinimum" : 42 , "type" : "integer" },
112+ },
113+ "required" : ["B" , "C" , "D" ],
114+ "title" : "CustomPrototype" ,
115+ "type" : "object" ,
116+ }
117+
118+ with pytest .raises (InvalidPTypeError ):
119+ vetiver_create_prototype (response .json ())
0 commit comments