1+ from __future__ import absolute_import
2+ from __future__ import division
3+ from __future__ import print_function
4+
5+ from keras .optimizers import SGD
6+ from keras .optimizers import Adam
7+ from keras import backend as K
8+ K .set_image_dim_ordering ('tf' )
9+ import socket
10+ import os
11+
12+ # -------------------------------------------------
13+ # Background config:
14+ hostname = socket .gethostname ()
15+ if hostname == 'baymax' :
16+ path_var = 'baymax/'
17+ elif hostname == 'walle' :
18+ path_var = 'walle/'
19+ elif hostname == 'bender' :
20+ path_var = 'bender/'
21+ else :
22+ path_var = 'zhora/'
23+
24+ DATA_DIR = '/local_home/JAAD_Dataset/iros/resized_imgs_128/train/'
25+ # DATA_DIR= '/local_home/data/KITTI_data/'
26+
27+ TEST_DATA_DIR = '/local_home/JAAD_Dataset/iros/resized_imgs_128/test/'
28+
29+ VAL_DATA_DIR = '/local_home/JAAD_Dataset/iros/resized_imgs_128/val/'
30+
31+ PRETRAINED_C3D = '/home/pratik/git_projects/c3d-keras/models/sports1M_weights_tf.json'
32+ PRETRAINED_C3D_WEIGHTS = '/home/pratik/git_projects/c3d-keras/models/sports1M_weights_tf.h5'
33+
34+ MODEL_DIR = './../' + path_var + 'models'
35+ if not os .path .exists (MODEL_DIR ):
36+ os .mkdir (MODEL_DIR )
37+
38+ CHECKPOINT_DIR = './../' + path_var + 'checkpoints'
39+ if not os .path .exists (CHECKPOINT_DIR ):
40+ os .mkdir (CHECKPOINT_DIR )
41+
42+ GEN_IMAGES_DIR = './../' + path_var + 'generated_images'
43+ if not os .path .exists (GEN_IMAGES_DIR ):
44+ os .mkdir (GEN_IMAGES_DIR )
45+
46+ CLA_GEN_IMAGES_DIR = GEN_IMAGES_DIR + '/cla_gen/'
47+ if not os .path .exists (CLA_GEN_IMAGES_DIR ):
48+ os .mkdir (CLA_GEN_IMAGES_DIR )
49+
50+ ATTN_WEIGHTS_DIR = './../' + path_var + 'attn_weights'
51+ if not os .path .exists (ATTN_WEIGHTS_DIR ):
52+ os .mkdir (ATTN_WEIGHTS_DIR )
53+
54+ LOG_DIR = './../' + path_var + 'logs'
55+ if not os .path .exists (LOG_DIR ):
56+ os .mkdir (LOG_DIR )
57+
58+ TF_LOG_DIR = './../' + path_var + 'tf_logs'
59+ if not os .path .exists (TF_LOG_DIR ):
60+ os .mkdir (TF_LOG_DIR )
61+
62+ TF_LOG_CLA_DIR = './../' + path_var + 'tf_cla_logs'
63+ if not os .path .exists (TF_LOG_CLA_DIR ):
64+ os .mkdir (TF_LOG_CLA_DIR )
65+
66+ TEST_RESULTS_DIR = './../' + path_var + 'test_results'
67+ if not os .path .exists (TEST_RESULTS_DIR ):
68+ os .mkdir (TEST_RESULTS_DIR )
69+
70+ PRINT_MODEL_SUMMARY = True
71+ SAVE_MODEL = True
72+ PLOT_MODEL = True
73+ SAVE_GENERATED_IMAGES = True
74+ SHUFFLE = True
75+ VIDEO_LENGTH = 32
76+ IMG_SIZE = (128 , 128 , 3 )
77+ VIS_ATTN = True
78+ ATTN_COEFF = 0
79+ # KL coeff damages learning
80+ KL_COEFF = 0
81+ CLASSIFIER = True
82+ RAM_DECIMATE = True
83+ RETRAIN_CLASSIFIER = False
84+ CLASS_TARGET_INDEX = 24
85+ LOSS_WEIGHTS = [1. , 0.5 ]
86+
87+ ped_actions = ['slow down' , 'moving slow' , 'standing' , 'stopped' ,
88+ 'speed up' , 'moving fast' , 'look' , 'looking' , 'clear path' ,
89+ 'crossing' , 'nod' , 'handwave' , 'unknown' ]
90+
91+ simple_ped_set = ['crossing' , 'stopped' , 'looking' , 'clear path' , 'unknown' ]
92+
93+ driver_actions = ['moving slow' , 'slowing down' , 'standing' , 'speeding up' , 'moving fast' ]
94+ simple_driver_set = ['slow down' , 'stop' , 'speed up' ]
95+
96+ joint_action_set = ['moving slow' , 'slowing down' , 'standing' , 'speeding up' , 'moving fast' ,
97+ 'slow down' , 'standing' , 'moving fast' , 'speed up' , 'look' , 'nod' , 'unknown' ,
98+ 'moving slow' , 'flasher signal' , 'looking' , 'handwave' , 'clear path' ,
99+ 'stopped' , 'slowing down' , 'crossing' , 'speeding up' ]
100+
101+ formatted_joint_action_set = ['car moving slow' , 'car slowing down' , 'car standing' , 'car speeding up' , 'car moving fast' ,
102+ 'ped slow down' , 'ped standing' , 'ped moving fast' , 'ped speed up' , 'ped look' ,
103+ 'ped nod' , 'ped unknown' , 'ped moving slow' , 'ped flasher signal' , 'ped looking' ,
104+ 'ped handwave' , 'ped clear path' , 'ped stopped' , 'ped slowing down' , 'ped crossing' ,
105+ 'ped speeding up' ]
106+
107+
108+ # -------------------------------------------------
109+ # Network configuration:
110+ print ("Loading network/training configuration..." )
111+ print ("Config file: " + str (__name__ ))
112+
113+ BATCH_SIZE = 9
114+ NB_EPOCHS_AUTOENCODER = 0
115+ NB_EPOCHS_CLASS = 100
116+
117+ OPTIM_A = Adam (lr = 0.0001 , beta_1 = 0.5 )
118+ # OPTIM_C = Adam(lr=0.0000002, beta_1=0.5)
119+ OPTIM_C = SGD (lr = 0.0001 , momentum = 0.9 , nesterov = True )
120+
121+
122+ auto_lr_schedule = [25 , 30 , 35 ] # epoch_step
123+ def auto_schedule (epoch_idx ):
124+ if (epoch_idx + 1 ) < auto_lr_schedule [0 ]:
125+ return 0.0001
126+ elif (epoch_idx + 1 ) < auto_lr_schedule [1 ]:
127+ return 0.00001 # lr_decay_ratio = 10
128+ elif (epoch_idx + 1 ) < auto_lr_schedule [2 ]:
129+ return 0.00001
130+ return 0.00001
131+
132+
133+ clas_lr_schedule = [10 , 11 , 12 ] # epoch_step
134+ def clas_schedule (epoch_idx ):
135+ if (epoch_idx + 1 ) < clas_lr_schedule [0 ]:
136+ return 0.0001
137+ elif (epoch_idx + 1 ) < clas_lr_schedule [1 ]:
138+ return 0.00001 # lr_decay_ratio = 10
139+ elif (epoch_idx + 1 ) < clas_lr_schedule [2 ]:
140+ return 0.000001
141+ return 0.0000001
0 commit comments