5252import torch
5353import PIL .Image as PILImage
5454
55-
5655# ============ Utility & Preprocessing ============
5756
57+
def now_millis() -> int:
    """Return the current wall-clock time in whole milliseconds since the epoch."""
    seconds = time.time()
    return int(seconds * 1000)
6060
@@ -71,9 +71,9 @@ def load_image_from_uri(uri: str) -> bytes:
7171 return f .read ()
7272
7373
74- def decode_to_tensor (
75- image_bytes : bytes ,
76- resize_shorter_side : Optional [int ] = None ) -> torch .Tensor :
74+ def decode_to_tens (
75+ image_bytes : bytes ,
76+ resize_shorter_side : Optional [int ] = None ) -> torch .Tensor :
7777 """Decode bytes -> RGB PIL -> optional resize -> float tensor [0..1], CHW.
7878
7979 Note: TorchVision detection models apply their own normalization internally.
@@ -105,6 +105,7 @@ def sha1_hex(s: str) -> str:
105105
106106# ============ DoFns ============
107107
108+
108109class MakeKeyDoFn (beam .DoFn ):
109110 """Produce (image_id, uri) where image_id is stable for dedup and keys."""
110111 def process (self , element : str ):
@@ -123,19 +124,21 @@ def process(self, kv: Tuple[str, str]):
123124 start = now_millis ()
124125 try :
125126 b = load_image_from_uri (uri )
126- tensor = decode_to_tensor (b , resize_shorter_side = self .resize_shorter_side )
127+ tensor = decode_to_tens (b , resize_shorter_side = self .resize_shorter_side )
127128 preprocess_ms = now_millis () - start
128- yield image_id , {"tensor" : tensor , "preprocess_ms" : preprocess_ms , "uri" : uri }
129+ yield image_id , {
130+ "tensor" : tensor , "preprocess_ms" : preprocess_ms , "uri" : uri
131+ }
129132 except Exception as e :
130133 logging .warning ("Decode failed for %s (%s): %s" , image_id , uri , e )
131134 return
132135
133136
134137def _torchvision_detection_inference_fn (
135- model , batch : List [torch .Tensor ], device : str ) -> List [Dict [str , Any ]]:
138+ model , batch : List [torch .Tensor ], device : str ) -> List [Dict [str , Any ]]:
136139 """Custom inference for TorchVision detection models.
137140
138- TorchVision detection models expect: List[Tensor] (each tensor : CHW float [0..1]).
141+ TorchVision detection models expect: List[Tensor] (each: CHW float [0..1]).
139142 """
140143 with torch .no_grad ():
141144 inputs = []
@@ -152,10 +155,7 @@ def _torchvision_detection_inference_fn(
152155class PostProcessDoFn (beam .DoFn ):
153156 """PredictionResult -> dict row for BQ."""
154157 def __init__ (
155- self ,
156- model_name : str ,
157- score_threshold : float ,
158- max_detections : int ):
158+ self , model_name : str , score_threshold : float , max_detections : int ):
159159 self .model_name = model_name
160160 self .score_threshold = score_threshold
161161 self .max_detections = max_detections
@@ -191,9 +191,9 @@ def _extract_detection(self, inference_obj: Any) -> Dict[str, Any]:
191191 continue
192192 box = boxes_list [i ] # [x1,y1,x2,y2]
193193 dets .append ({
194- "label_id" : int (labels_list [i ]),
195- "score" : score ,
196- "box" : [float (box [0 ]), float (box [1 ]), float (box [2 ]), float (box [3 ])],
194+ "label_id" : int (labels_list [i ]),
195+ "score" : score ,
196+ "box" : [float (box [0 ]), float (box [1 ]), float (box [2 ]), float (box [3 ])],
197197 })
198198 if len (dets ) >= self .max_detections :
199199 break
@@ -214,8 +214,7 @@ def process(self, kv: Tuple[str, PredictionResult]):
214214
215215 if not isinstance (inference_obj , dict ):
216216 logging .warning (
217- "Unexpected inference type for %s: %s" , image_id , type (inference_obj )
218- )
217+ "Unexpected inf-ce type for %s: %s" , image_id , type (inference_obj ))
219218 yield {
220219 "image_id" : image_id ,
221220 "model_name" : self .model_name ,
@@ -238,32 +237,32 @@ def process(self, kv: Tuple[str, PredictionResult]):
238237
239238# ============ Args & Helpers ============
240239
240+
241241def parse_known_args (argv ):
242242 parser = argparse .ArgumentParser ()
243243
244244 # I/O & runtime
245245 parser .add_argument ('--mode' , default = 'batch' , choices = ['batch' ])
246246 parser .add_argument (
247- '--output_table' ,
248- required = True ,
249- help = 'BigQuery output table: dataset.table' )
247+ '--output_table' ,
248+ required = True ,
249+ help = 'BigQuery output table: dataset.table' )
250250 parser .add_argument (
251- '--publish_to_big_query' , default = 'true' , choices = ['true' , 'false' ])
251+ '--publish_to_big_query' , default = 'true' , choices = ['true' , 'false' ])
252252 parser .add_argument (
253- '--input' ,
254- required = True ,
255- help = 'GCS path to file with image URIs' )
253+ '--input' , required = True , help = 'GCS path to file with image URIs' )
256254
257255 # Model & inference
258256 parser .add_argument (
259- '--pretrained_model_name' ,
260- default = 'fasterrcnn_resnet50_fpn' ,
261- help = ('TorchVision detection model name '
257+ '--pretrained_model_name' ,
258+ default = 'fasterrcnn_resnet50_fpn' ,
259+ help = (
260+ 'TorchVision detection model name '
262261 '(e.g., fasterrcnn_resnet50_fpn)' ))
263262 parser .add_argument (
264- '--model_state_dict_path' ,
265- required = True ,
266- help = 'GCS path to a state_dict .pth for the chosen model' )
263+ '--model_state_dict_path' ,
264+ required = True ,
265+ help = 'GCS path to a state_dict .pth for the chosen model' )
267266 parser .add_argument ('--device' , default = 'GPU' , choices = ['CPU' , 'GPU' ])
268267
269268 # Batch sizing (no right-fitting)
@@ -311,8 +310,9 @@ def create_torchvision_detection_model(model_name: str):
311310
312311# ============ Main pipeline ============
313312
313+
314314def run (
315- argv = None , save_main_session = True , test_pipeline = None ) -> PipelineResult :
315+ argv = None , save_main_session = True , test_pipeline = None ) -> PipelineResult :
316316 known_args , pipeline_args = parse_known_args (argv )
317317
318318 pipeline_options = PipelineOptions (pipeline_args )
@@ -328,72 +328,63 @@ def run(
328328 batch_size = int (known_args .inference_batch_size )
329329
330330 model_handler = PytorchModelHandlerTensor (
331- model_class = lambda : create_torchvision_detection_model (
332- known_args .pretrained_model_name
333- ),
334- model_params = {},
335- state_dict_path = known_args .model_state_dict_path ,
336- device = device ,
337- inference_batch_size = batch_size ,
338- inference_fn = _torchvision_detection_inference_fn ,
331+ model_class = lambda : create_torchvision_detection_model (
332+ known_args .pretrained_model_name
333+ ),
334+ model_params = {},
335+ state_dict_path = known_args .model_state_dict_path ,
336+ device = device ,
337+ inference_batch_size = batch_size ,
338+ inference_fn = _torchvision_detection_inference_fn ,
339339 )
340340
341341 pipeline = test_pipeline or beam .Pipeline (options = pipeline_options )
342342
343343 pcoll = (
344- pipeline
345- | 'ReadURIsBatch' >> beam .Create (
346- list (read_gcs_file_lines (known_args .input ))
347- )
348- | 'FilterEmptyBatch' >> beam .Filter (lambda s : s .strip ())
349- )
344+ pipeline
345+ | 'ReadURIsBatch' >> beam .Create (
346+ list (read_gcs_file_lines (known_args .input )))
347+ | 'FilterEmptyBatch' >> beam .Filter (lambda s : s .strip ()))
350348
351- keyed = (
352- pcoll
353- | 'MakeKey' >> beam .ParDo (MakeKeyDoFn ())
354- )
349+ keyed = (pcoll | 'MakeKey' >> beam .ParDo (MakeKeyDoFn ()))
355350
356351 # Batch exactly-once behavior:
357352 # 1) Dedup by key within the run to ensure stable writes.
358353 # 2) Use FILE_LOADS for BQ to avoid streaming insert duplicates in retries.
359354 keyed = keyed | 'DistinctByKey' >> beam .Distinct ()
360355
361356 preprocessed = (
362- keyed
363- | 'DecodePreprocess' >> beam .ParDo (
364- DecodePreprocessDoFn (resize_shorter_side = resize_shorter_side ))
365- )
357+ keyed
358+ | 'DecodePreprocess' >> beam .ParDo (
359+ DecodePreprocessDoFn (resize_shorter_side = resize_shorter_side )))
366360
367361 to_infer = (
368- preprocessed
369- | 'ToKeyedTensor' >> beam .Map (lambda kv : (kv [0 ], kv [1 ]["tensor" ]))
370- )
362+ preprocessed
363+ | 'ToKeyedTensor' >> beam .Map (lambda kv : (kv [0 ], kv [1 ]["tensor" ])))
371364
372365 predictions = (
373- to_infer
374- | 'RunInference' >> RunInference (KeyedModelHandler (model_handler ))
375- )
366+ to_infer
367+ | 'RunInference' >> RunInference (KeyedModelHandler (model_handler )))
376368
377369 results = (
378- predictions
379- | 'PostProcess' >> beam .ParDo (
380- PostProcessDoFn (
381- model_name = known_args .pretrained_model_name ,
382- score_threshold = known_args .score_threshold ,
383- max_detections = known_args .max_detections ))
384- )
370+ predictions
371+ | 'PostProcess' >> beam .ParDo (
372+ PostProcessDoFn (
373+ model_name = known_args .pretrained_model_name ,
374+ score_threshold = known_args .score_threshold ,
375+ max_detections = known_args .max_detections )))
385376
386377 if known_args .publish_to_big_query == 'true' :
387378 _ = (
388- results
389- | 'WriteToBigQuery' >> beam .io .WriteToBigQuery (
390- known_args .output_table ,
391- schema = ('image_id:STRING, model_name:STRING, '
392- 'detections :STRING, num_detections:INT64, infer_ms:INT64' ),
393- write_disposition = beam . io . BigQueryDisposition . WRITE_APPEND ,
394- create_disposition = beam .io .BigQueryDisposition .CREATE_IF_NEEDED ,
395- method = beam .io .WriteToBigQuery . Method . FILE_LOADS )
396- )
379+ results
380+ | 'WriteToBigQuery' >> beam .io .WriteToBigQuery (
381+ known_args .output_table ,
382+ schema = (
383+ 'image_id :STRING, model_name:STRING, '
384+ 'detections:STRING, num_detections:INT64, infer_ms:INT64' ) ,
385+ write_disposition = beam .io .BigQueryDisposition .WRITE_APPEND ,
386+ create_disposition = beam .io .BigQueryDisposition . CREATE_IF_NEEDED ,
387+ method = beam . io . WriteToBigQuery . Method . FILE_LOADS ) )
397388
398389 result = pipeline .run ()
399390 result .wait_until_finish (duration = 1800000 ) # 30 min