Skip to content

Commit 6143546

Browse files
committed
spaCy handler no longer needs to know the column name to predict
1 parent 9d3577a commit 6143546

2 files changed

Lines changed: 63 additions & 36 deletions

File tree

vetiver/handlers/spacy.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from .base import BaseHandler
22
from ..prototype import vetiver_create_prototype
3+
from ..helpers import api_data_to_frame
4+
35
import pandas as pd
46

57
spacy_exists = True
@@ -31,10 +33,16 @@ def construct_prototype(self):
3133
prototype :
3234
Input data prototype for spacy model
3335
"""
34-
if self.prototype_data is not None:
35-
raise TypeError
36+
if self.prototype_data is None:
37+
text_column_name = "text"
38+
39+
else:
40+
if len(self.prototype_data.columns) != 1:
41+
raise TypeError("Expected 1 column of text data")
3642

37-
prototype = vetiver_create_prototype(pd.DataFrame({"text": ["text"]}))
43+
text_column_name = self.prototype_data.columns[0]
44+
45+
prototype = vetiver_create_prototype(pd.DataFrame({text_column_name: ["text"]}))
3846

3947
return prototype
4048

@@ -60,7 +68,9 @@ def handler_predict(self, input_data, check_prototype):
6068

6169
response_body = []
6270

63-
for doc in self.model.pipe(input_data.text):
71+
input_data = api_data_to_frame(input_data)
72+
73+
for doc in self.model.pipe(input_data.iloc[:, 0]):
6474
response_body.append(doc.to_json())
6575

6676
return pd.Series(response_body)

vetiver/tests/test_spacy.py

Lines changed: 49 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -27,63 +27,80 @@ def animal_component_function(doc):
2727
matcher.add("ANIMAL", animals)
2828
nlp.add_pipe("animals")
2929

30-
return vetiver.VetiverModel(nlp, "animals")
30+
return nlp
3131

3232

3333
@pytest.fixture
34-
def vetiver_client(spacy_model): # With check_prototype=True
35-
app = vetiver.VetiverAPI(spacy_model, check_prototype=True)
34+
def vetiver_client_with_prototype(spacy_model): # With check_prototype=True
35+
df = pd.DataFrame({"new_column": ["one", "two", "three"]})
36+
v = vetiver.VetiverModel(spacy_model, "animals", prototype_data=df)
37+
app = vetiver.VetiverAPI(v, check_prototype=True)
3638
app.app.root_path = "/predict"
3739
client = TestClient(app.app)
3840

3941
return client
4042

4143

4244
@pytest.fixture
43-
<<<<<<< HEAD
44-
def vetiver_client_check_ptype_false(spacy_model): # With check_prototype=False
45-
app = vetiver.VetiverAPI(spacy_model, check_prototype=True)
46-
<<<<<<< HEAD
47-
=======
48-
def vetiver_client_check_ptype_false(spacy_model): # With check_ptype=True
49-
app = vetiver.VetiverAPI(spacy_model, check_ptype=False)
50-
>>>>>>> b600e38 (testing spacy)
51-
=======
52-
>>>>>>> 64158ab (specify in server to not preprocess data)
45+
def vetiver_client_no_prototype(spacy_model): # With check_prototype=False
46+
v = vetiver.VetiverModel(spacy_model, "animals")
47+
app = vetiver.VetiverAPI(v, check_prototype=False)
5348
app.app.root_path = "/predict"
5449
client = TestClient(app.app)
5550

5651
return client
5752

5853

59-
def test_vetiver_post(vetiver_client):
60-
df = pd.DataFrame({"text": ["one", "my turtle is smarter than my dog"]})
54+
def test_vetiver_predict_with_prototype(vetiver_client_with_prototype):
55+
df = pd.DataFrame({"new_column": ["turtles", "i have a dog"]})
6156

62-
response = vetiver.predict(endpoint=vetiver_client, data=df)
57+
response = vetiver.predict(endpoint=vetiver_client_with_prototype, data=df)
6358

6459
assert isinstance(response, pd.DataFrame), response
6560
assert response.to_dict() == {
6661
"predict": {
67-
0: {
68-
"text": "one",
62+
"0": {
63+
"text": "turtles",
6964
"ents": [],
70-
"sents": [{"start": 0, "end": 3}],
71-
"tokens": [{"id": 0, "start": 0, "end": 3}],
65+
"sents": [{"start": 0, "end": 7}],
66+
"tokens": [{"id": 0, "start": 0, "end": 7}],
7267
},
73-
1: {
74-
"text": "my turtle is smarter than my dog",
75-
"ents": [
76-
{"start": 3, "end": 9, "label": "ANIMAL"},
77-
{"start": 29, "end": 32, "label": "ANIMAL"},
68+
"1": {
69+
"text": "i have a dog",
70+
"ents": [{"start": 9, "end": 12, "label": "ANIMAL"}],
71+
"tokens": [
72+
{"id": 0, "start": 0, "end": 1},
73+
{"id": 1, "start": 2, "end": 6},
74+
{"id": 2, "start": 7, "end": 8},
75+
{"id": 3, "start": 9, "end": 12},
7876
],
77+
},
78+
}
79+
}
80+
81+
82+
def test_vetiver_predict_no_prototype(vetiver_client_no_prototype):
83+
df = pd.DataFrame({"uhhh": ["turtles", "i have a dog"]})
84+
85+
response = vetiver.predict(endpoint=vetiver_client_no_prototype, data=df)
86+
87+
assert isinstance(response, pd.DataFrame), response
88+
assert response.to_dict() == {
89+
"predict": {
90+
"0": {
91+
"text": "turtles",
92+
"ents": [],
93+
"sents": [{"start": 0, "end": 7}],
94+
"tokens": [{"id": 0, "start": 0, "end": 7}],
95+
},
96+
"1": {
97+
"text": "i have a dog",
98+
"ents": [{"start": 9, "end": 12, "label": "ANIMAL"}],
7999
"tokens": [
80-
{"id": 0, "start": 0, "end": 2},
81-
{"id": 1, "start": 3, "end": 9},
82-
{"id": 2, "start": 10, "end": 12},
83-
{"id": 3, "start": 13, "end": 20},
84-
{"id": 4, "start": 21, "end": 25},
85-
{"id": 5, "start": 26, "end": 28},
86-
{"id": 6, "start": 29, "end": 32},
100+
{"id": 0, "start": 0, "end": 1},
101+
{"id": 1, "start": 2, "end": 6},
102+
{"id": 2, "start": 7, "end": 8},
103+
{"id": 3, "start": 9, "end": 12},
87104
],
88105
},
89106
}

0 commit comments

Comments
 (0)