Skip to content

Commit e398220

Browse files
committed
Fix inference_fn
1 parent 0a80470 commit e398220

1 file changed

Lines changed: 20 additions & 7 deletions

File tree

sdks/python/apache_beam/examples/inference/pytorch_image_object_detection.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,13 @@
3131
import json
3232
import logging
3333
import time
34-
from typing import Iterable
35-
from typing import Optional
36-
from typing import Tuple
3734
from typing import Any
3835
from typing import Dict
36+
from typing import Iterable
3937
from typing import List
38+
from typing import Optional
39+
from typing import Sequence
40+
from typing import Tuple
4041

4142
import apache_beam as beam
4243
from apache_beam.io.filesystems import FileSystems
@@ -135,12 +136,24 @@ def process(self, kv: Tuple[str, str]):
135136

136137

137138
def _torchvision_detection_inference_fn(
138-
model, batch: List[torch.Tensor], device: str) -> List[Dict[str, Any]]:
139-
"""Custom inference for TorchVision detection models.
140-
141-
TorchVision detection models expect: List[Tensor] (each: CHW float [0..1]).
139+
batch: Sequence[torch.Tensor],
140+
model: torch.nn.Module,
141+
device: torch.device,
142+
inference_args: Optional[dict[str, Any]] = None,
143+
model_id: Optional[str] = None,
144+
) -> List[Dict[str, Any]]:
145+
"""Inference function for TorchVision detection models.
146+
147+
TorchVision detection models expect List[Tensor] where each tensor is:
148+
- shape: [3, H, W]
149+
- dtype: float32
150+
- values: [0..1]
142151
"""
152+
del inference_args
153+
del model_id
154+
143155
with torch.no_grad():
156+
# Ensure tensors are on device
144157
inputs = []
145158
for t in batch:
146159
if isinstance(t, torch.Tensor):

0 commit comments

Comments
 (0)