1- from abc import ABC , abstractmethod
2- import os
3- import importlib
4-
51from PIL import Image
62import numpy as np
73import torch
84
95from .catalog import PathManager , LABEL_MAP_CATALOG
10- from ..elements import *
6+ from ..base_layoutmodel import BaseLayoutModel
7+ from ...elements import Rectangle , TextBlock , Layout
118
129__all__ = ["Detectron2LayoutModel" ]
1310
1411
class BaseLayoutModel(ABC):
    """Abstract base class for layout detection models.

    Subclasses implement :meth:`detect` and declare two class-level
    attributes, ``DEPENDENCIES`` and ``MODULES``. The modules listed in
    ``MODULES`` are imported lazily, each time an instance is created, and
    bound as attributes on the class.
    """

    @abstractmethod
    def detect(self):
        """Run layout detection; must be implemented by subclasses."""
        pass

    # Add lazy loading mechanisms for layout models, refer to
    # layoutparser.ocr.BaseOCRAgent
    # TODO: Build a metaclass for lazy module loader
    @property
    @abstractmethod
    def DEPENDENCIES(self):
        """DEPENDENCIES lists all necessary dependencies for the class."""
        pass

    @property
    @abstractmethod
    def MODULES(self):
        """MODULES instructs how to import these necessary libraries.

        Each entry is a mapping with a ``"module_path"`` key (the dotted
        module to import) and an ``"import_name"`` key (the attribute name
        under which the imported module is stored on the class).
        """
        pass

    @classmethod
    def _import_module(cls):
        """Import every module listed in ``cls.MODULES`` and bind it on ``cls``.

        Raises:
            ModuleNotFoundError: If any listed module cannot be located. The
                message names the class and the packages in
                ``cls.DEPENDENCIES`` to install.
        """
        # `import importlib` alone does not guarantee that the `util`
        # submodule is loaded, so import it explicitly here.
        import importlib.util

        for m in cls.MODULES:
            module_path = m["module_path"]
            try:
                spec = importlib.util.find_spec(module_path)
            except ModuleNotFoundError:
                # For a dotted path, find_spec imports the parent package and
                # raises if that parent is missing; treat it as "not found"
                # so the installation hint below is shown instead.
                spec = None

            if spec is None:
                raise ModuleNotFoundError(
                    f"\n "
                    f"\n Please install the following libraries to support the class {cls.__name__} :"
                    f"\n pip install {' '.join(cls.DEPENDENCIES)} "
                    f"\n "
                )

            setattr(cls, m["import_name"], importlib.import_module(module_path))

    def __new__(cls, *args, **kwargs):
        """Load the required modules before allocating the instance."""
        cls._import_module()
        return super().__new__(cls)
54-
55-
5612class Detectron2LayoutModel (BaseLayoutModel ):
5713 """Create a Detectron2-based Layout Detection Model
5814
@@ -93,6 +49,7 @@ class Detectron2LayoutModel(BaseLayoutModel):
9349 },
9450 {"import_name" : "_config" , "module_path" : "detectron2.config" },
9551 ]
52+ DETECTOR_NAME = "detectron2"
9653
9754 def __init__ (
9855 self ,
@@ -111,18 +68,47 @@ def __init__(
11168 extra_config .extend (["MODEL.DEVICE" , "cpu" ])
11269
11370 cfg = self ._config .get_cfg ()
71+ config_path = self ._reconstruct_path_with_detector_name (config_path )
11472 config_path = PathManager .get_local_path (config_path )
11573 cfg .merge_from_file (config_path )
11674 cfg .merge_from_list (extra_config )
11775
11876 if model_path is not None :
77+ model_path = self ._reconstruct_path_with_detector_name (model_path )
11978 cfg .MODEL .WEIGHTS = model_path
12079 cfg .MODEL .DEVICE = "cuda" if torch .cuda .is_available () else "cpu"
12180 self .cfg = cfg
12281
12382 self .label_map = label_map
12483 self ._create_model ()
12584
85+ def _reconstruct_path_with_detector_name (self , path : str ) -> str :
86+ """This function will add the detector name (detectron2) into the
87+ lp model config path to get the "canonical" model name.
88+
89+ For example, for a given config_path `lp://HJDataset/faster_rcnn_R_50_FPN_3x/config`,
90+ it will transform it into `lp://detectron2/HJDataset/faster_rcnn_R_50_FPN_3x/config`.
91+ However, if the config_path already contains the detector name, we won't change it.
92+
93+ This function is a general step to support multiple backends in the layout-parser
94+ library.
95+
96+ Args:
97+ path (str): The given input path that might or might not contain the detector name.
98+
99+ Returns:
100+ str: a modified path that contains the detector name.
101+ """
102+ if path .startswith ("lp://" ): # TODO: Move "lp://" to a constant
103+ model_name = path [len ("lp://" ) :]
104+ model_name_segments = model_name .split ("/" )
105+ if (
106+ len (model_name_segments ) == 3
107+ and "detectron2" not in model_name_segments
108+ ):
109+ return "lp://" + self .DETECTOR_NAME + "/" + path [len ("lp://" ) :]
110+ return path
111+
126112 def gather_output (self , outputs ):
127113
128114 instance_pred = outputs ["instances" ].to ("cpu" )
0 commit comments