Skip to content

Commit 3072893

Browse files
author
spencer@primus
committed
Add default split probs for carla
1 parent cf466ed commit 3072893

2 files changed

Lines changed: 10 additions & 4 deletions

File tree

avapi/_dataset.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,15 @@ def list_scenes(self):
6262
def get_splits_scenes(self):
6363
return self.splits_scenes
6464

65-
def make_splits_scenes(self, seed=1, frac_train=0.7, frac_val=0.3):
65+
def make_splits_scenes(self, seed=1, frac_train=0.7, frac_val=0.3, frac_test=0.0):
6666
"""Split the scenes by hashing the experiment name and modding
6767
3:1 split using mod 4
6868
"""
6969
rng = random.Random(seed)
7070

71+
if not (frac_train + frac_val + frac_test) == 1.0:
72+
raise ValueError("Fractions must add to 1.0")
73+
7174
# first two we alternate just to have one
7275
splits_scenes = {"train": [], "val": [], "test": []}
7376
for i, scene in enumerate(self.scenes):

avapi/carla/dataset.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,18 +68,21 @@ class CarlaScenesManager(BaseSceneManager):
6868
nominal_whitelist_types = _nominal_whitelist_types
6969
nominal_ignore_types = _nominal_ignore_types
7070

71-
def __init__(self, data_dir, split=None, verbose=False):
71+
def __init__(self, data_dir, verbose=False,
72+
split_fracs = {"train": 0.6, "val": 0.2, "test": 0.2}, seed: int = 1):
7273
"""
7374
data_dir: the base folder where all scenes are kept
7475
"""
7576
if not os.path.exists(data_dir):
7677
raise RuntimeError(f"Cannot find data dir at {data_dir}")
7778
self.data_dir = data_dir
78-
self.split = split
7979
self.verbose = verbose
8080
self.scenes = sorted(next(os.walk(data_dir))[1])
8181
self.splits_scenes = self.make_splits_scenes(
82-
seed=1, frac_train=0.7, frac_val=0.3
82+
seed=seed,
83+
frac_train=split_fracs["train"],
84+
frac_val=split_fracs["val"],
85+
frac_test=split_fracs["test"],
8386
)
8487

8588
def get_scene_dataset_by_index(self, scene_idx):

0 commit comments

Comments
 (0)