Skip to content

Commit e687e76

Browse files
author
spencer@primus
committed
Add type hinting for dataset evaluation
1 parent 116464a commit e687e76

4 files changed

Lines changed: 21 additions & 33 deletions

File tree

avapi/_dataset.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import numpy as np
77
from avstack import calibration, sensors
8+
from avstack.datastructs import DataContainer
89
from avstack.environment.objects import (
910
ObjectState,
1011
ObjectStateDecoder,
@@ -267,40 +268,41 @@ def get_objects(
267268
agent=None,
268269
max_dist=None,
269270
max_occ=None,
271+
in_global: bool = False,
270272
**kwargs,
271-
):
273+
) -> DataContainer:
272274
reference = self.get_ego_reference(frame, agent=agent)
273275
sensor = self.get_sensor_name(sensor, agent=agent)
276+
timestamp = self.get_timestamp(frame=frame, sensor=sensor, agent=agent)
274277
objs = self._load_objects(frame, sensor=sensor, agent=agent, **kwargs)
275278
if max_occ is not None:
276-
objs = np.array(
277-
[
279+
objs = [
278280
obj
279281
for obj in objs
280282
if (obj.occlusion <= max_occ)
281283
or (obj.occlusion == Occlusion.UNKNOWN)
282-
]
283-
)
284+
]
284285
if max_dist is not None:
285286
if sensor == "ego":
286287
calib = calibration.Calibration(reference)
287288
else:
288289
calib = self.get_calibration(frame, sensor, agent=agent)
289-
objs = np.array(
290-
[
290+
objs = [
291291
obj
292292
for obj in objs
293293
if obj.position.distance(calib.reference) < max_dist
294-
]
295-
)
294+
]
295+
objs = DataContainer(source_identifier=sensor, frame=frame, timestamp=timestamp, data=objs)
296+
if in_global:
297+
objs = objs.apply_and_return("change_reference", GlobalOrigin3D, inplace=False)
296298
return objs
297299

298300
def get_objects_global(
299301
self,
300302
frame,
301303
max_dist: Union[Tuple[ReferenceFrame, float], None] = None,
302304
**kwargs,
303-
):
305+
) -> DataContainer:
304306
return self._load_objects_global(frame, max_dist=max_dist, **kwargs)
305307

306308
def get_number_of_objects(self, frame, **kwargs):

avapi/evaluation/__init__.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,3 @@
1-
# -*- coding: utf-8 -*-
2-
# @Author: spencer@primus
3-
# @Date: 2022-05-30
4-
# @Last Modified by: spencer@primus
5-
# @Last Modified time: 2022-09-12
6-
7-
81
from . import metrics
92
from .base import ResultManager
103
from .perception import (
@@ -14,3 +7,5 @@
147
from .prediction import get_predict_results_from_folder
158
from .tracking import get_track_results_from_folder, get_track_results_from_multi_folder
169
from .trades import run_trades
10+
11+

avapi/evaluation/ospa.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ def _cost(list_shorter: list, list_longer: list, p: float = 1.0, c: float = 1.0)
2121
return distance
2222

2323
@staticmethod
24-
def cost(tracks: list, truths: list):
24+
def cost(tracks: list, truths: list, p: float = 1.0, c: float = 1.0):
2525
if len(tracks) <= len(truths):
26-
return OspaMetric._cost(tracks, truths)
26+
return OspaMetric._cost(tracks, truths, p=p, c=c)
2727
else:
28-
return OspaMetric._cost(truths, tracks)
28+
return OspaMetric._cost(truths, tracks, p=p, c=c)

avapi/nuscenes/dataset.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,3 @@
1-
# -*- coding: utf-8 -*-
2-
# @Author: Spencer H
3-
# @Date: 2022-09-05
4-
# @Last Modified by: Spencer H
5-
# @Last Modified date: 2022-09-29
6-
# @Description:
7-
"""
8-
9-
"""
101
import logging
112
import os
123
import struct
@@ -59,18 +50,18 @@ def __init__(self, data_dir, split="v1.0-mini", verbose=False):
5950
for k, vs in splits_scenes.items()
6051
}
6152

62-
def list_scenes(self):
53+
def list_scenes(self) -> List[str]:
6354
self.nuX.list_scenes()
6455

65-
def get_scene_dataset_by_name(self, scene_name):
56+
def get_scene_dataset_by_name(self, scene_name) -> "nuScenesSceneDataset":
6657
idx = self.scene_name_to_index[scene_name]
6758
return self.get_scene_dataset_by_index(idx)
6859

69-
def get_scene_dataset_by_scene_number(self, scene_number):
60+
def get_scene_dataset_by_scene_number(self, scene_number) -> "nuScenesSceneDataset":
7061
idx = self.scene_number_to_index[scene_number]
7162
return self.get_scene_dataset_by_index(idx)
7263

73-
def get_scene_dataset_by_index(self, scene_idx):
64+
def get_scene_dataset_by_index(self, scene_idx) -> "nuScenesSceneDataset":
7465
return nuScenesSceneDataset(
7566
self.data_dir,
7667
self.split,

0 commit comments

Comments
 (0)