-
Notifications
You must be signed in to change notification settings - Fork 58
Expand file tree
/
Copy pathbase_inference_engine.py
More file actions
92 lines (83 loc) · 2.72 KB
/
base_inference_engine.py
File metadata and controls
92 lines (83 loc) · 2.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# Import dependencies
from abc import ABC, abstractmethod
from inference.exceptions import InvalidModelConfiguration, ModelNotLoaded, ApplicationError
# Abstract inference engine class
class AbstractInferenceEngine(ABC):
    """
    Base class for inference engines.

    Lifecycle: the constructor calls ``validate_configuration()`` and then
    ``load()``. When the instance is destroyed, ``free()`` is invoked -- but
    only if ``load()`` completed successfully during construction.
    """

    def __init__(self, model_path):
        """
        Validates the model configuration, then loads the model.
        :param model_path: The model's path
        :raises Exception: Any exception raised by validate_configuration()
            propagates unchanged.
        :raises ApplicationError: Propagated unchanged if raised by load().
        :raises ModelNotLoaded: If load() fails with any other exception; the
            original exception is chained as ``__cause__``.
        """
        # Set first so __del__ can tell whether load() ever succeeded, even if
        # construction aborts part-way through.
        self._model_loaded = False
        self.labels = []
        self.configuration = {}
        self.model_path = model_path
        self.validate_configuration()
        try:
            self.load()
        except ApplicationError:
            # Application-level errors already carry meaning; don't wrap them.
            raise
        except Exception as e:
            raise ModelNotLoaded() from e
        self._model_loaded = True

    @abstractmethod
    def load(self):
        """
        Loads the model based on the underlying implementation.
        Called once from __init__, after validate_configuration().
        """
        pass

    @abstractmethod
    def free(self):
        """
        Performs any manual memory cleanup required when unloading a model.
        Called from the destructor, and only if load() succeeded.
        """
        pass

    @abstractmethod
    async def run(self, input_data, draw_boxes, predict_batch):
        """
        Performs the required inference based on the underlying implementation of this class.
        Could be used to return classification predictions, object detection coordinates...
        :param input_data: A single image
        :param draw_boxes: Used to draw bounding boxes on image instead of returning them
        :param predict_batch: Boolean flag; meaning is implementation-defined
        :return: Implementation-defined inference result (e.g. predictions or bounding boxes)
        """
        pass

    @abstractmethod
    async def run_batch(self, input_data, draw_boxes, predict_batch):
        """
        Iterates over images and returns a prediction for each one.
        :param input_data: List of images
        :param draw_boxes: Used to draw bounding boxes on image instead of returning them
        :param predict_batch: Boolean flag; meaning is implementation-defined
        :return: List of per-image results (e.g. bounding boxes)
        """
        pass

    @abstractmethod
    def validate_configuration(self):
        """
        Validates that the model and its files are valid based on the underlying implementation's requirements.
        Can check for configuration values, folder structure...
        Called from __init__ before load().
        """
        pass

    @abstractmethod
    def set_configuration(self, data):
        """
        Takes the configuration from the config.json file.
        :param data: Json data
        """
        pass

    @abstractmethod
    def validate_json_configuration(self, data):
        """
        Validates the configuration of the config.json file.
        :param data: Json data
        """
        pass

    def __del__(self):
        """
        Releases the model via free(), but only if load() succeeded in __init__.
        """
        # getattr: the flag is absent if __init__ never ran (e.g. the instance
        # came from __new__ directly or a subclass skipped super().__init__).
        if getattr(self, "_model_loaded", False):
            # Clear before freeing so a repeated __del__ call cannot double-free.
            self._model_loaded = False
            self.free()