-
Notifications
You must be signed in to change notification settings - Fork 58
Expand file tree
/
Copy pathinference_engines_factory.py
More file actions
32 lines (31 loc) · 1.39 KB
/
inference_engines_factory.py
File metadata and controls
32 lines (31 loc) · 1.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# Import dependencies
import os
import json
from inference.exceptions import ModelNotFound, ApplicationError, InvalidModelConfiguration, InferenceEngineNotFound, ModelNotLoaded
# Inference engine factory class
class InferenceEngineFactory:
@staticmethod
def get_engine(path_to_model):
"""
Reads the model's inference engine from the model's configuration and calls the right inference engine class.
:param path_to_model: Model's path
:return: The model's instance
"""
if not os.path.exists(path_to_model):
raise ModelNotFound()
try:
configuration = json.loads(open(os.path.join(path_to_model, 'config.json')).read())
except Exception:
raise InvalidModelConfiguration('config.json not found or corrupted')
try:
inference_engine_name = configuration['inference_engine_name']
except Exception:
raise InvalidModelConfiguration('missing inference engine name in config.json')
try:
# import one of the available inference engine class (in this project there's only one), and return a
# model instance
return getattr(__import__(inference_engine_name), 'InferenceEngine')(path_to_model)
except ApplicationError as e:
raise e
except Exception as e:
raise InferenceEngineNotFound(inference_engine_name)