Skip to content

Commit 9141aba

Browse files
committed
add appropriate helper for spacy
1 parent 6143546 commit 9141aba

3 files changed

Lines changed: 93 additions & 64 deletions

File tree

vetiver/handlers/spacy.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,17 @@ def construct_prototype(self):
3737
text_column_name = "text"
3838

3939
else:
40-
if len(self.prototype_data.columns) != 1:
40+
if (
41+
isinstance(self.prototype_data, pd.DataFrame)
42+
and len(self.prototype_data.columns) != 1
43+
):
4144
raise TypeError("Expected 1 column of text data")
4245

43-
text_column_name = self.prototype_data.columns[0]
46+
text_column_name = (
47+
self.prototype_data.columns[0]
48+
if isinstance(self.prototype_data, pd.DataFrame)
49+
else list(self.prototype_data.keys())[0]
50+
)
4451

4552
prototype = vetiver_create_prototype(pd.DataFrame({text_column_name: ["text"]}))
4653

vetiver/helpers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ def _(pred_data):
2828
return pd.DataFrame([dict(s) for s in pred_data])
2929

3030

@api_data_to_frame.register(pd.DataFrame)
def _pd_frame(pred_data):
    """Pass a DataFrame through unchanged.

    `api_data_to_frame` normalizes incoming prediction payloads to a
    DataFrame; when the caller already supplies one, no conversion is
    needed, so this registration is the identity.
    """
    return pred_data
36+
3137
@api_data_to_frame.register(dict)
3238
def _dict(pred_data):
3339
return api_data_to_frame([pred_data])

vetiver/tests/test_spacy.py

Lines changed: 78 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -5,28 +5,30 @@
55
import numpy as np # noqa
66
import pandas as pd # noqa
77
from fastapi.testclient import TestClient # noqa
8-
8+
from numpy import nan # noqa
99
import vetiver # noqa
1010

1111

@spacy.language.Language.component("animals")
def animal_component_function(doc):
    """Mark every phrase-matcher hit on *doc* as an ANIMAL entity."""
    matches = matcher(doc)  # noqa
    doc.ents = [
        spacy.tokens.Span(doc, start, end, label="ANIMAL")
        for match_id, start, end in matches
    ]
    return doc


# Build the pipeline once at module scope: Language.component registration
# is global, so doing this inside the fixture would re-register "animals"
# on every fixture instantiation.
nlp = spacy.blank("en")
animals = list(nlp.pipe(["dog", "cat", "turtle"]))
matcher = spacy.matcher.PhraseMatcher(nlp.vocab)
matcher.add("ANIMAL", animals)
nlp.add_pipe("animals")
1230
@pytest.fixture
1331
def spacy_model():
14-
@spacy.language.Language.component("animals")
15-
def animal_component_function(doc):
16-
matches = matcher(doc) # noqa
17-
spans = [
18-
spacy.tokens.Span(doc, start, end, label="ANIMAL")
19-
for match_id, start, end in matches
20-
]
21-
doc.ents = spans
22-
return doc
23-
24-
nlp = spacy.blank("en")
25-
animals = list(nlp.pipe(["dog", "cat", "turtle"]))
26-
matcher = spacy.matcher.PhraseMatcher(nlp.vocab)
27-
matcher.add("ANIMAL", animals)
28-
nlp.add_pipe("animals")
29-
3032
return nlp
3133

3234

@@ -58,24 +60,23 @@ def test_vetiver_predict_with_prototype(vetiver_client_with_prototype):
5860

5961
assert isinstance(response, pd.DataFrame), response
6062
assert response.to_dict() == {
61-
"predict": {
62-
"0": {
63-
"text": "turtles",
64-
"ents": [],
65-
"sents": [{"start": 0, "end": 7}],
66-
"tokens": [{"id": 0, "start": 0, "end": 7}],
67-
},
68-
"1": {
69-
"text": "i have a dog",
70-
"ents": [{"start": 9, "end": 12, "label": "ANIMAL"}],
71-
"tokens": [
72-
{"id": 0, "start": 0, "end": 1},
73-
{"id": 1, "start": 2, "end": 6},
74-
{"id": 2, "start": 7, "end": 8},
75-
{"id": 3, "start": 9, "end": 12},
76-
],
77-
},
78-
}
63+
"0": {
64+
"text": "turtles",
65+
"ents": [],
66+
"sents": [{"start": 0, "end": 7}],
67+
"tokens": [{"id": 0, "start": 0, "end": 7}],
68+
},
69+
"1": {
70+
"text": "i have a dog",
71+
"ents": [{"start": 9, "end": 12, "label": "ANIMAL"}],
72+
"sents": nan,
73+
"tokens": [
74+
{"id": 0, "start": 0, "end": 1},
75+
{"id": 1, "start": 2, "end": 6},
76+
{"id": 2, "start": 7, "end": 8},
77+
{"id": 3, "start": 9, "end": 12},
78+
],
79+
},
7980
}
8081

8182

@@ -86,34 +87,49 @@ def test_vetiver_predict_no_prototype(vetiver_client_no_prototype):
8687

8788
assert isinstance(response, pd.DataFrame), response
8889
assert response.to_dict() == {
89-
"predict": {
90-
"0": {
91-
"text": "turtles",
92-
"ents": [],
93-
"sents": [{"start": 0, "end": 7}],
94-
"tokens": [{"id": 0, "start": 0, "end": 7}],
95-
},
96-
"1": {
97-
"text": "i have a dog",
98-
"ents": [{"start": 9, "end": 12, "label": "ANIMAL"}],
99-
"tokens": [
100-
{"id": 0, "start": 0, "end": 1},
101-
{"id": 1, "start": 2, "end": 6},
102-
{"id": 2, "start": 7, "end": 8},
103-
{"id": 3, "start": 9, "end": 12},
104-
],
105-
},
106-
}
90+
"0": {
91+
"text": "turtles",
92+
"ents": [],
93+
"sents": [{"start": 0, "end": 7}],
94+
"tokens": [{"id": 0, "start": 0, "end": 7}],
95+
},
96+
"1": {
97+
"text": "i have a dog",
98+
"ents": [{"start": 9, "end": 12, "label": "ANIMAL"}],
99+
"sents": nan,
100+
"tokens": [
101+
{"id": 0, "start": 0, "end": 1},
102+
{"id": 1, "start": 2, "end": 6},
103+
{"id": 2, "start": 7, "end": 8},
104+
{"id": 3, "start": 9, "end": 12},
105+
],
106+
},
107107
}
108108

109109

def test_serialize_no_prototype(spacy_model):
    """Round-trip a prototype-less spaCy VetiverModel through a temp pin board."""
    import pins

    board = pins.board_temp(allow_pickle_read=True)
    v = vetiver.VetiverModel(spacy_model, "animals")
    vetiver.vetiver_pin_write(board=board, model=v)
    v2 = vetiver.VetiverModel.from_pin(board, "animals")
    # Deserialization should hand back a usable English pipeline object.
    assert isinstance(v2.model, spacy.lang.en.English)
def test_serialize_prototype(spacy_model):
    """Round-trip a spaCy VetiverModel with prototype data through a temp pin board."""
    import pins

    board = pins.board_temp(allow_pickle_read=True)
    proto = pd.DataFrame({"text": ["text"]})
    v = vetiver.VetiverModel(spacy_model, "animals", prototype_data=proto)
    vetiver.vetiver_pin_write(board=board, model=v)
    v2 = vetiver.VetiverModel.from_pin(board, "animals")
    # Deserialization should hand back a usable English pipeline object.
    assert isinstance(v2.model, spacy.lang.en.English)

0 commit comments

Comments
 (0)