Skip to content

Commit 169b022

Browse files
committed
Fix inference result
1 parent e398220 commit 169b022

1 file changed

Lines changed: 6 additions & 3 deletions

File tree

sdks/python/apache_beam/examples/inference/pytorch_image_object_detection.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,9 +219,12 @@ def _extract_detection(self, inference_obj: Any) -> Dict[str, Any]:
219219
def process(self, kv: Tuple[str, PredictionResult]):
220220
image_id, pred = kv
221221

222-
# pred.inference should be torchvision dict for this element,
223-
# but keep robust fallback.
224-
inference_obj = pred.inference
222+
# pred can be PredictionResult OR raw torchvision dict.
223+
if hasattr(pred, "inference"):
224+
inference_obj = pred.inference
225+
else:
226+
inference_obj = pred
227+
225228
if isinstance(inference_obj, list) and len(inference_obj) == 1:
226229
inference_obj = inference_obj[0]
227230

0 commit comments

Comments
 (0)