Skip to content

Commit 27731e8

Browse files
committed
Trasition sequence training
1 parent 3a431b7 commit 27731e8

22 files changed

Lines changed: 898 additions & 862 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,4 @@ code/autoencoder_model/walle/*
2222
code/autoencoder_model/gifs/imgs/*
2323
code/autoencoder_model/zhora/history/*
2424
code/autoencoder_model/bender/*
25+
code/autoencoder_model/gifs/*
353 KB
Loading
364 KB
Loading
-207 KB
Binary file not shown.

code/autoencoder_model/scripts/attention_autoencoder.py

Lines changed: 106 additions & 188 deletions
Large diffs are not rendered by default.

code/autoencoder_model/scripts/classifier.py

Lines changed: 161 additions & 135 deletions
Large diffs are not rendered by default.

code/autoencoder_model/scripts/config_aa.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929

3030
VAL_DATA_DIR= '/local_home/JAAD_Dataset/iros/resized_imgs_128/val/'
3131

32-
TEST_DATA_DIR= '/local_home/JAAD_Dataset/iros/resized_imgs_128/test/'
32+
# TEST_DATA_DIR= '/local_home/JAAD_Dataset/iros/resized_imgs_128/test/'
33+
TEST_DATA_DIR= '/local_home/JAAD_Dataset/fun_experiments/resized/'
3334

3435
MODEL_DIR = './../' + path_var + 'models'
3536
if not os.path.exists(MODEL_DIR):
@@ -64,17 +65,18 @@
6465
PLOT_MODEL = True
6566
SAVE_GENERATED_IMAGES = True
6667
SHUFFLE = True
67-
VIDEO_LENGTH = 20
68+
VIDEO_LENGTH = 32
6869
IMG_SIZE = (128, 128, 3)
6970
ATTN_COEFF = 0
7071
KL_COEFF = 0
72+
RAM_DECIMATE = True
7173

7274
# -------------------------------------------------
7375
# Network configuration:
7476
print ("Loading network/training configuration.")
7577
print ("Config file: " + str(__name__))
7678

77-
BATCH_SIZE = 1
79+
BATCH_SIZE = 20
7880
NB_EPOCHS_AUTOENCODER = 40
7981

8082
OPTIM_A = Adam(lr=0.0001, beta_1=0.5)
@@ -87,9 +89,9 @@ def schedule(epoch_idx):
8789
if (epoch_idx + 1) < lr_schedule[0]:
8890
return 0.0001
8991
elif (epoch_idx + 1) < lr_schedule[1]:
90-
return 0.00001 # lr_decay_ratio = 10
92+
return 0.0001 # lr_decay_ratio = 10
9193
elif (epoch_idx + 1) < lr_schedule[2]:
92-
return 0.000001
93-
return 0.000001
94+
return 0.0001
95+
return 0.0001
9496

9597

4 Bytes
Binary file not shown.

code/autoencoder_model/scripts/config_classifier.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@
7373
SAVE_GENERATED_IMAGES = True
7474
SHUFFLE = True
7575
VIDEO_LENGTH = 32
76-
IMG_SIZE = (112, 112, 3)
76+
IMG_SIZE = (128, 128, 3)
7777
VIS_ATTN = True
7878
ATTN_COEFF = 0
7979
# KL coeff damages learning
@@ -129,12 +129,12 @@ def auto_schedule(epoch_idx):
129129
return 0.00001
130130

131131

132-
clas_lr_schedule = [50, 55, 60] # epoch_step
132+
clas_lr_schedule = [10, 11, 12] # epoch_step
133133
def clas_schedule(epoch_idx):
134134
if (epoch_idx + 1) < clas_lr_schedule[0]:
135135
return 0.0001
136136
elif (epoch_idx + 1) < clas_lr_schedule[1]:
137137
return 0.00001 # lr_decay_ratio = 10
138138
elif (epoch_idx + 1) < clas_lr_schedule[2]:
139139
return 0.000001
140-
return 0.000001
140+
return 0.0000001
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
from __future__ import absolute_import
2+
from __future__ import division
3+
from __future__ import print_function
4+
5+
from keras.optimizers import SGD
6+
from keras.optimizers import Adam
7+
from keras import backend as K
8+
K.set_image_dim_ordering('tf')
9+
import socket
10+
import os
11+
12+
# -------------------------------------------------
13+
# Background config:
14+
hostname = socket.gethostname()
15+
if hostname == 'baymax':
16+
path_var = 'baymax/'
17+
elif hostname == 'walle':
18+
path_var = 'walle/'
19+
elif hostname == 'bender':
20+
path_var = 'bender/'
21+
else:
22+
path_var = 'zhora/'
23+
24+
DATA_DIR= '/local_home/JAAD_Dataset/iros/resized_imgs_128/train/'
25+
# DATA_DIR= '/local_home/data/KITTI_data/'
26+
27+
TEST_DATA_DIR= '/local_home/JAAD_Dataset/iros/resized_imgs_128/test/'
28+
29+
VAL_DATA_DIR= '/local_home/JAAD_Dataset/iros/resized_imgs_128/val/'
30+
31+
PRETRAINED_C3D= '/home/pratik/git_projects/c3d-keras/models/sports1M_weights_tf.json'
32+
PRETRAINED_C3D_WEIGHTS= '/home/pratik/git_projects/c3d-keras/models/sports1M_weights_tf.h5'
33+
34+
MODEL_DIR = './../' + path_var + 'models'
35+
if not os.path.exists(MODEL_DIR):
36+
os.mkdir(MODEL_DIR)
37+
38+
CHECKPOINT_DIR = './../' + path_var + 'checkpoints'
39+
if not os.path.exists(CHECKPOINT_DIR):
40+
os.mkdir(CHECKPOINT_DIR)
41+
42+
GEN_IMAGES_DIR = './../' + path_var + 'generated_images'
43+
if not os.path.exists(GEN_IMAGES_DIR):
44+
os.mkdir(GEN_IMAGES_DIR)
45+
46+
CLA_GEN_IMAGES_DIR = GEN_IMAGES_DIR + '/cla_gen/'
47+
if not os.path.exists(CLA_GEN_IMAGES_DIR):
48+
os.mkdir(CLA_GEN_IMAGES_DIR)
49+
50+
ATTN_WEIGHTS_DIR = './../' + path_var + 'attn_weights'
51+
if not os.path.exists(ATTN_WEIGHTS_DIR):
52+
os.mkdir(ATTN_WEIGHTS_DIR)
53+
54+
LOG_DIR = './../' + path_var + 'logs'
55+
if not os.path.exists(LOG_DIR):
56+
os.mkdir(LOG_DIR)
57+
58+
TF_LOG_DIR = './../' + path_var + 'tf_logs'
59+
if not os.path.exists(TF_LOG_DIR):
60+
os.mkdir(TF_LOG_DIR)
61+
62+
TF_LOG_CLA_DIR = './../' + path_var + 'tf_cla_logs'
63+
if not os.path.exists(TF_LOG_CLA_DIR):
64+
os.mkdir(TF_LOG_CLA_DIR)
65+
66+
TEST_RESULTS_DIR = './../' + path_var + 'test_results'
67+
if not os.path.exists(TEST_RESULTS_DIR):
68+
os.mkdir(TEST_RESULTS_DIR)
69+
70+
PRINT_MODEL_SUMMARY = True
71+
SAVE_MODEL = True
72+
PLOT_MODEL = True
73+
SAVE_GENERATED_IMAGES = True
74+
SHUFFLE = True
75+
VIDEO_LENGTH = 32
76+
IMG_SIZE = (128, 128, 3)
77+
VIS_ATTN = True
78+
ATTN_COEFF = 0
79+
# KL coeff damages learning
80+
KL_COEFF = 0
81+
CLASSIFIER = True
82+
RAM_DECIMATE = True
83+
RETRAIN_CLASSIFIER = False
84+
CLASS_TARGET_INDEX = 24
85+
LOSS_WEIGHTS = [1., 0.5]
86+
87+
ped_actions = ['slow down', 'moving slow', 'standing', 'stopped',
88+
'speed up', 'moving fast', 'look', 'looking', 'clear path',
89+
'crossing', 'nod', 'handwave', 'unknown']
90+
91+
simple_ped_set = ['crossing', 'stopped', 'looking', 'clear path', 'unknown']
92+
93+
driver_actions = ['moving slow', 'slowing down', 'standing', 'speeding up', 'moving fast']
94+
simple_driver_set = ['slow down', 'stop', 'speed up']
95+
96+
joint_action_set = ['moving slow', 'slowing down', 'standing', 'speeding up', 'moving fast',
97+
'slow down', 'standing', 'moving fast', 'speed up', 'look', 'nod', 'unknown',
98+
'moving slow', 'flasher signal', 'looking' , 'handwave', 'clear path',
99+
'stopped', 'slowing down', 'crossing', 'speeding up']
100+
101+
formatted_joint_action_set = ['car moving slow', 'car slowing down', 'car standing', 'car speeding up', 'car moving fast',
102+
'ped slow down', 'ped standing', 'ped moving fast', 'ped speed up', 'ped look',
103+
'ped nod', 'ped unknown', 'ped moving slow', 'ped flasher signal', 'ped looking' ,
104+
'ped handwave', 'ped clear path', 'ped stopped', 'ped slowing down', 'ped crossing',
105+
'ped speeding up']
106+
107+
108+
# -------------------------------------------------
109+
# Network configuration:
110+
print ("Loading network/training configuration...")
111+
print ("Config file: " + str(__name__))
112+
113+
BATCH_SIZE = 9
114+
NB_EPOCHS_AUTOENCODER = 0
115+
NB_EPOCHS_CLASS = 100
116+
117+
OPTIM_A = Adam(lr=0.0001, beta_1=0.5)
118+
# OPTIM_C = Adam(lr=0.0000002, beta_1=0.5)
119+
OPTIM_C = SGD(lr=0.0001, momentum=0.9, nesterov=True)
120+
121+
122+
auto_lr_schedule = [25, 30, 35] # epoch_step
123+
def auto_schedule(epoch_idx):
124+
if (epoch_idx + 1) < auto_lr_schedule[0]:
125+
return 0.0001
126+
elif (epoch_idx + 1) < auto_lr_schedule[1]:
127+
return 0.00001 # lr_decay_ratio = 10
128+
elif (epoch_idx + 1) < auto_lr_schedule[2]:
129+
return 0.00001
130+
return 0.00001
131+
132+
133+
clas_lr_schedule = [10, 11, 12] # epoch_step
134+
def clas_schedule(epoch_idx):
135+
if (epoch_idx + 1) < clas_lr_schedule[0]:
136+
return 0.0001
137+
elif (epoch_idx + 1) < clas_lr_schedule[1]:
138+
return 0.00001 # lr_decay_ratio = 10
139+
elif (epoch_idx + 1) < clas_lr_schedule[2]:
140+
return 0.000001
141+
return 0.0000001

0 commit comments

Comments
 (0)