|
31 | 31 | import json |
32 | 32 | import logging |
33 | 33 | import time |
34 | | -from typing import Iterable |
35 | | -from typing import Optional |
36 | | -from typing import Tuple |
37 | 34 | from typing import Any |
38 | 35 | from typing import Dict |
| 36 | +from typing import Iterable |
39 | 37 | from typing import List |
| 38 | +from typing import Optional |
| 39 | +from typing import Sequence |
| 40 | +from typing import Tuple |
40 | 41 |
|
41 | 42 | import apache_beam as beam |
42 | 43 | from apache_beam.io.filesystems import FileSystems |
@@ -135,12 +136,24 @@ def process(self, kv: Tuple[str, str]): |
135 | 136 |
|
136 | 137 |
|
137 | 138 | def _torchvision_detection_inference_fn( |
138 | | - model, batch: List[torch.Tensor], device: str) -> List[Dict[str, Any]]: |
139 | | - """Custom inference for TorchVision detection models. |
140 | | -
|
141 | | - TorchVision detection models expect: List[Tensor] (each: CHW float [0..1]). |
| 139 | + batch: Sequence[torch.Tensor], |
| 140 | + model: torch.nn.Module, |
| 141 | + device: torch.device, |
| 142 | + inference_args: Optional[dict[str, Any]] = None, |
| 143 | + model_id: Optional[str] = None, |
| 144 | +) -> List[Dict[str, Any]]: |
| 145 | + """Inference function for TorchVision detection models. |
| 146 | +
|
| 147 | + TorchVision detection models expect List[Tensor] where each tensor is: |
| 148 | + - shape: [3, H, W] |
| 149 | + - dtype: float32 |
| 150 | + - values: [0..1] |
142 | 151 | """ |
| 152 | + del inference_args |
| 153 | + del model_id |
| 154 | + |
143 | 155 | with torch.no_grad(): |
| 156 | + # Ensure tensors are on device |
144 | 157 | inputs = [] |
145 | 158 | for t in batch: |
146 | 159 | if isinstance(t, torch.Tensor): |
|
0 commit comments