Skip to content

Commit 281509a

Browse files
committed
Refactoring
1 parent 6a09efb commit 281509a

5 files changed

Lines changed: 74 additions & 83 deletions

File tree

.github/workflows/beam_Inference_Python_Benchmarks_Dataflow.yml

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -92,7 +92,7 @@ jobs:
9292
${{ github.workspace }}/.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Pytorch_Sentiment_Streaming_DistilBert_Base_Uncased.txt
9393
${{ github.workspace }}/.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Pytorch_Sentiment_Batch_DistilBert_Base_Uncased.txt
9494
${{ github.workspace }}/.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_VLLM_Gemma_Batch.txt
95-
${{ github.workspace }}/.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Pytorch_Image_Detection.txt
95+
${{ github.workspace }}/.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Pytorch_Image_Object_Detection.txt
9696
# The env variables are created and populated in the test-arguments-action as "<github.job>_test_arguments_<argument_file_paths_index>"
9797
- name: get current time
9898
run: echo "NOW_UTC=$(date '+%m%d%H%M%S' --utc)" >> $GITHUB_ENV
@@ -197,8 +197,8 @@ jobs:
197197
with:
198198
gradle-command: :sdks:python:apache_beam:testing:load_tests:run
199199
arguments: |
200-
-PloadTest.mainClass=apache_beam.testing.benchmarks.inference.pytorch_object_detection_benchmarks \
200+
-PloadTest.mainClass=apache_beam.testing.benchmarks.inference.pytorch_image_object_detection_benchmarks \
201201
-Prunner=DataflowRunner \
202202
-PpythonVersion=3.10 \
203-
-PloadTest.requirementsTxtFile=apache_beam/ml/inference/pytorch_object_detection_requirements.txt \
203+
-PloadTest.requirementsTxtFile=apache_beam/ml/inference/pytorch_image_object_detection_requirements.txt \
204204
'-PloadTest.args=${{ env.beam_Inference_Python_Benchmarks_Dataflow_test_arguments_9 }} --mode=batch --job_name=benchmark-tests-pytorch-object_detection-batch-${{env.NOW_UTC}} --output_table=apache-beam-testing.beam_run_inference.result_torch_inference_object_detection_batch' \

.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Pytorch_Image_Detection.txt renamed to .github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Pytorch_Image_Object_Detection.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
--autoscaling_algorithm=NONE
2222
--staging_location=gs://temp-storage-for-perf-tests/loadtests
2323
--temp_location=gs://temp-storage-for-perf-tests/loadtests
24-
--requirements_file=apache_beam/ml/inference/pytorch_object_detection_requirements.txt
24+
--requirements_file=apache_beam/ml/inference/pytorch_image_object_detection_requirements.txt
2525
--publish_to_big_query=true
2626
--metrics_dataset=beam_run_inference
2727
--metrics_table=result_torch_inference_object_detection_batch

sdks/python/apache_beam/examples/inference/pytorch_image_object_detection.py

Lines changed: 66 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@
5252
import torch
5353
import PIL.Image as PILImage
5454

55-
5655
# ============ Utility & Preprocessing ============
5756

57+
5858
def now_millis() -> int:
5959
return int(time.time() * 1000)
6060

@@ -71,9 +71,9 @@ def load_image_from_uri(uri: str) -> bytes:
7171
return f.read()
7272

7373

74-
def decode_to_tensor(
75-
image_bytes: bytes,
76-
resize_shorter_side: Optional[int] = None) -> torch.Tensor:
74+
def decode_to_tens(
75+
image_bytes: bytes,
76+
resize_shorter_side: Optional[int] = None) -> torch.Tensor:
7777
"""Decode bytes -> RGB PIL -> optional resize -> float tensor [0..1], CHW.
7878
7979
Note: TorchVision detection models apply their own normalization internally.
@@ -105,6 +105,7 @@ def sha1_hex(s: str) -> str:
105105

106106
# ============ DoFns ============
107107

108+
108109
class MakeKeyDoFn(beam.DoFn):
109110
"""Produce (image_id, uri) where image_id is stable for dedup and keys."""
110111
def process(self, element: str):
@@ -123,19 +124,21 @@ def process(self, kv: Tuple[str, str]):
123124
start = now_millis()
124125
try:
125126
b = load_image_from_uri(uri)
126-
tensor = decode_to_tensor(b, resize_shorter_side=self.resize_shorter_side)
127+
tensor = decode_to_tens(b, resize_shorter_side=self.resize_shorter_side)
127128
preprocess_ms = now_millis() - start
128-
yield image_id, {"tensor": tensor, "preprocess_ms": preprocess_ms, "uri": uri}
129+
yield image_id, {
130+
"tensor": tensor, "preprocess_ms": preprocess_ms, "uri": uri
131+
}
129132
except Exception as e:
130133
logging.warning("Decode failed for %s (%s): %s", image_id, uri, e)
131134
return
132135

133136

134137
def _torchvision_detection_inference_fn(
135-
model, batch: List[torch.Tensor], device: str) -> List[Dict[str, Any]]:
138+
model, batch: List[torch.Tensor], device: str) -> List[Dict[str, Any]]:
136139
"""Custom inference for TorchVision detection models.
137140
138-
TorchVision detection models expect: List[Tensor] (each tensor: CHW float [0..1]).
141+
TorchVision detection models expect: List[Tensor] (each: CHW float [0..1]).
139142
"""
140143
with torch.no_grad():
141144
inputs = []
@@ -152,10 +155,7 @@ def _torchvision_detection_inference_fn(
152155
class PostProcessDoFn(beam.DoFn):
153156
"""PredictionResult -> dict row for BQ."""
154157
def __init__(
155-
self,
156-
model_name: str,
157-
score_threshold: float,
158-
max_detections: int):
158+
self, model_name: str, score_threshold: float, max_detections: int):
159159
self.model_name = model_name
160160
self.score_threshold = score_threshold
161161
self.max_detections = max_detections
@@ -191,9 +191,9 @@ def _extract_detection(self, inference_obj: Any) -> Dict[str, Any]:
191191
continue
192192
box = boxes_list[i] # [x1,y1,x2,y2]
193193
dets.append({
194-
"label_id": int(labels_list[i]),
195-
"score": score,
196-
"box": [float(box[0]), float(box[1]), float(box[2]), float(box[3])],
194+
"label_id": int(labels_list[i]),
195+
"score": score,
196+
"box": [float(box[0]), float(box[1]), float(box[2]), float(box[3])],
197197
})
198198
if len(dets) >= self.max_detections:
199199
break
@@ -214,8 +214,7 @@ def process(self, kv: Tuple[str, PredictionResult]):
214214

215215
if not isinstance(inference_obj, dict):
216216
logging.warning(
217-
"Unexpected inference type for %s: %s", image_id, type(inference_obj)
218-
)
217+
"Unexpected inf-ce type for %s: %s", image_id, type(inference_obj))
219218
yield {
220219
"image_id": image_id,
221220
"model_name": self.model_name,
@@ -238,32 +237,32 @@ def process(self, kv: Tuple[str, PredictionResult]):
238237

239238
# ============ Args & Helpers ============
240239

240+
241241
def parse_known_args(argv):
242242
parser = argparse.ArgumentParser()
243243

244244
# I/O & runtime
245245
parser.add_argument('--mode', default='batch', choices=['batch'])
246246
parser.add_argument(
247-
'--output_table',
248-
required=True,
249-
help='BigQuery output table: dataset.table')
247+
'--output_table',
248+
required=True,
249+
help='BigQuery output table: dataset.table')
250250
parser.add_argument(
251-
'--publish_to_big_query', default='true', choices=['true', 'false'])
251+
'--publish_to_big_query', default='true', choices=['true', 'false'])
252252
parser.add_argument(
253-
'--input',
254-
required=True,
255-
help='GCS path to file with image URIs')
253+
'--input', required=True, help='GCS path to file with image URIs')
256254

257255
# Model & inference
258256
parser.add_argument(
259-
'--pretrained_model_name',
260-
default='fasterrcnn_resnet50_fpn',
261-
help=('TorchVision detection model name '
257+
'--pretrained_model_name',
258+
default='fasterrcnn_resnet50_fpn',
259+
help=(
260+
'TorchVision detection model name '
262261
'(e.g., fasterrcnn_resnet50_fpn)'))
263262
parser.add_argument(
264-
'--model_state_dict_path',
265-
required=True,
266-
help='GCS path to a state_dict .pth for the chosen model')
263+
'--model_state_dict_path',
264+
required=True,
265+
help='GCS path to a state_dict .pth for the chosen model')
267266
parser.add_argument('--device', default='GPU', choices=['CPU', 'GPU'])
268267

269268
# Batch sizing (no right-fitting)
@@ -311,8 +310,9 @@ def create_torchvision_detection_model(model_name: str):
311310

312311
# ============ Main pipeline ============
313312

313+
314314
def run(
315-
argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult:
315+
argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult:
316316
known_args, pipeline_args = parse_known_args(argv)
317317

318318
pipeline_options = PipelineOptions(pipeline_args)
@@ -328,72 +328,63 @@ def run(
328328
batch_size = int(known_args.inference_batch_size)
329329

330330
model_handler = PytorchModelHandlerTensor(
331-
model_class=lambda: create_torchvision_detection_model(
332-
known_args.pretrained_model_name
333-
),
334-
model_params={},
335-
state_dict_path=known_args.model_state_dict_path,
336-
device=device,
337-
inference_batch_size=batch_size,
338-
inference_fn=_torchvision_detection_inference_fn,
331+
model_class=lambda: create_torchvision_detection_model(
332+
known_args.pretrained_model_name
333+
),
334+
model_params={},
335+
state_dict_path=known_args.model_state_dict_path,
336+
device=device,
337+
inference_batch_size=batch_size,
338+
inference_fn=_torchvision_detection_inference_fn,
339339
)
340340

341341
pipeline = test_pipeline or beam.Pipeline(options=pipeline_options)
342342

343343
pcoll = (
344-
pipeline
345-
| 'ReadURIsBatch' >> beam.Create(
346-
list(read_gcs_file_lines(known_args.input))
347-
)
348-
| 'FilterEmptyBatch' >> beam.Filter(lambda s: s.strip())
349-
)
344+
pipeline
345+
| 'ReadURIsBatch' >> beam.Create(
346+
list(read_gcs_file_lines(known_args.input)))
347+
| 'FilterEmptyBatch' >> beam.Filter(lambda s: s.strip()))
350348

351-
keyed = (
352-
pcoll
353-
| 'MakeKey' >> beam.ParDo(MakeKeyDoFn())
354-
)
349+
keyed = (pcoll | 'MakeKey' >> beam.ParDo(MakeKeyDoFn()))
355350

356351
# Batch exactly-once behavior:
357352
# 1) Dedup by key within the run to ensure stable writes.
358353
# 2) Use FILE_LOADS for BQ to avoid streaming insert duplicates in retries.
359354
keyed = keyed | 'DistinctByKey' >> beam.Distinct()
360355

361356
preprocessed = (
362-
keyed
363-
| 'DecodePreprocess' >> beam.ParDo(
364-
DecodePreprocessDoFn(resize_shorter_side=resize_shorter_side))
365-
)
357+
keyed
358+
| 'DecodePreprocess' >> beam.ParDo(
359+
DecodePreprocessDoFn(resize_shorter_side=resize_shorter_side)))
366360

367361
to_infer = (
368-
preprocessed
369-
| 'ToKeyedTensor' >> beam.Map(lambda kv: (kv[0], kv[1]["tensor"]))
370-
)
362+
preprocessed
363+
| 'ToKeyedTensor' >> beam.Map(lambda kv: (kv[0], kv[1]["tensor"])))
371364

372365
predictions = (
373-
to_infer
374-
| 'RunInference' >> RunInference(KeyedModelHandler(model_handler))
375-
)
366+
to_infer
367+
| 'RunInference' >> RunInference(KeyedModelHandler(model_handler)))
376368

377369
results = (
378-
predictions
379-
| 'PostProcess' >> beam.ParDo(
380-
PostProcessDoFn(
381-
model_name=known_args.pretrained_model_name,
382-
score_threshold=known_args.score_threshold,
383-
max_detections=known_args.max_detections))
384-
)
370+
predictions
371+
| 'PostProcess' >> beam.ParDo(
372+
PostProcessDoFn(
373+
model_name=known_args.pretrained_model_name,
374+
score_threshold=known_args.score_threshold,
375+
max_detections=known_args.max_detections)))
385376

386377
if known_args.publish_to_big_query == 'true':
387378
_ = (
388-
results
389-
| 'WriteToBigQuery' >> beam.io.WriteToBigQuery(
390-
known_args.output_table,
391-
schema=('image_id:STRING, model_name:STRING, '
392-
'detections:STRING, num_detections:INT64, infer_ms:INT64'),
393-
write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND,
394-
create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED,
395-
method=beam.io.WriteToBigQuery.Method.FILE_LOADS)
396-
)
379+
results
380+
| 'WriteToBigQuery' >> beam.io.WriteToBigQuery(
381+
known_args.output_table,
382+
schema=(
383+
'image_id:STRING, model_name:STRING, '
384+
'detections:STRING, num_detections:INT64, infer_ms:INT64'),
385+
write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND,
386+
create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED,
387+
method=beam.io.WriteToBigQuery.Method.FILE_LOADS))
397388

398389
result = pipeline.run()
399390
result.wait_until_finish(duration=1800000) # 30 min

sdks/python/apache_beam/ml/inference/pytorch_object_detection_requirements.txt renamed to sdks/python/apache_beam/ml/inference/pytorch_image_object_detection_requirements.txt

File renamed without changes.

sdks/python/apache_beam/testing/benchmarks/inference/pytorch_object_detection_benchmarks.py renamed to sdks/python/apache_beam/testing/benchmarks/inference/pytorch_image_object_detection_benchmarks.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,15 @@ class PytorchImageObjectDetectionBenchmarkTest(DataflowCostBenchmark):
2626
def __init__(self):
2727
self.metrics_namespace = 'BeamML_PyTorch'
2828
super().__init__(
29-
metrics_namespace=self.metrics_namespace,
30-
pcollection='PostProcess.out0')
29+
metrics_namespace=self.metrics_namespace,
30+
pcollection='PostProcess.out0')
3131

3232
def test(self):
3333
extra_opts = {}
3434
extra_opts['input'] = self.pipeline.get_option('input_file')
3535
self.result = pytorch_image_object_detection.run(
36-
self.pipeline.get_full_options_as_args(**extra_opts),
37-
test_pipeline=self.pipeline)
36+
self.pipeline.get_full_options_as_args(**extra_opts),
37+
test_pipeline=self.pipeline)
3838

3939

4040
if __name__ == '__main__':

0 commit comments

Comments
 (0)