1+ from __future__ import absolute_import
2+ from __future__ import division
3+ from __future__ import print_function
4+
5+ from keras .optimizers import SGD
6+ from keras .optimizers import Adam
7+ from keras .optimizers import RMSprop
8+ from keras import backend as K
9+ K .set_image_dim_ordering ('tf' )
10+ import socket
11+ import os
12+
13+ # -------------------------------------------------
14+ # Background config:
15+ hostname = socket .gethostname ()
16+ if hostname == 'baymax' :
17+ path_var = 'baymax/'
18+ elif hostname == 'walle' :
19+ path_var = 'walle/'
20+ elif hostname == 'bender' :
21+ path_var = 'bender/'
22+ else :
23+ path_var = 'zhora/'
24+
25+ DATA_DIR = '/local_home/JAAD_Dataset/iros/resized_imgs_208_sorted/train/'
26+ # DATA_DIR= '/local_home/data/KITTI_data/'
27+
28+ TEST_DATA_DIR = '/local_home/JAAD_Dataset/iros/resized_imgs_208_sorted/test/'
29+
30+ VAL_DATA_DIR = '/local_home/JAAD_Dataset/iros/resized_imgs_208_sorted/val/'
31+
32+ PRETRAINED_C3D = '/home/pratik/git_projects/c3d-keras/models/sports1M_weights_tf.json'
33+ PRETRAINED_C3D_WEIGHTS = '/home/pratik/git_projects/c3d-keras/models/sports1M_weights_tf.h5'
34+
35+ MODEL_DIR = './../' + path_var + 'models'
36+ if not os .path .exists (MODEL_DIR ):
37+ os .mkdir (MODEL_DIR )
38+
39+ CHECKPOINT_DIR = './../' + path_var + 'checkpoints'
40+ if not os .path .exists (CHECKPOINT_DIR ):
41+ os .mkdir (CHECKPOINT_DIR )
42+
43+ GEN_IMAGES_DIR = './../' + path_var + 'generated_images'
44+ if not os .path .exists (GEN_IMAGES_DIR ):
45+ os .mkdir (GEN_IMAGES_DIR )
46+
47+ CLA_GEN_IMAGES_DIR = GEN_IMAGES_DIR + '/cla_gen/'
48+ if not os .path .exists (CLA_GEN_IMAGES_DIR ):
49+ os .mkdir (CLA_GEN_IMAGES_DIR )
50+
51+ # ATTN_WEIGHTS_DIR = './../' + path_var + 'attn_weights'
52+ # if not os.path.exists(ATTN_WEIGHTS_DIR):
53+ # os.mkdir(ATTN_WEIGHTS_DIR)
54+
55+ LOG_DIR = './../' + path_var + 'logs'
56+ if not os .path .exists (LOG_DIR ):
57+ os .mkdir (LOG_DIR )
58+
59+ TF_LOG_DIR = './../' + path_var + 'tf_logs'
60+ if not os .path .exists (TF_LOG_DIR ):
61+ os .mkdir (TF_LOG_DIR )
62+
63+ TF_LOG_CLA_DIR = './../' + path_var + 'tf_cla_logs'
64+ if not os .path .exists (TF_LOG_CLA_DIR ):
65+ os .mkdir (TF_LOG_CLA_DIR )
66+
67+ TEST_RESULTS_DIR = './../' + path_var + 'test_results'
68+ if not os .path .exists (TEST_RESULTS_DIR ):
69+ os .mkdir (TEST_RESULTS_DIR )
70+
71+ PRINT_MODEL_SUMMARY = True
72+ SAVE_MODEL = True
73+ PLOT_MODEL = False
74+ SAVE_GENERATED_IMAGES = True
75+ SHUFFLE = True
76+ VIDEO_LENGTH = 16
77+ IMG_SIZE = (128 , 208 , 3 )
78+ VIS_ATTN = True
79+ CLASSIFIER = True
80+ BUF_SIZE = 10
81+ LOSS_WEIGHTS = [1 , 1 ]
82+ A_TRAIN_RATIO = 1
83+ C_TRAIN_RATIO = 1
84+ RAM_DECIMATE = True
85+ RETRAIN_CLASSIFIER = True
86+ CLASS_TARGET_INDEX = 8
87+ ROT_MAX = 10
88+ SFT_H_MAX = 0.05
89+ SFT_V_MAX = 0.05
90+ ZOOM_MAX = 0.2
91+ BRIGHT_RANGE_L = 0.5
92+ BRIGHT_RANGE_H = 1.5
93+
94+ ped_actions = ['slow down' , 'standing' , 'walking' , 'speed up' , 'nod' , 'unknown' ,
95+ 'clear path' , 'handwave' , 'crossing' , 'looking' , 'no ped' ]
96+
97+ simple_ped_set = ['standing' ,'crossing' , 'no ped' ]
98+
99+
100+
101+
102+
103+ driver_actions = ['moving slow' , 'slowing down' , 'standing' , 'speeding up' , 'moving fast' ]
104+ simple_driver_set = ['slow down' , 'stop' , 'speed up' ]
105+
106+ joint_action_set = ['moving slow' , 'slowing down' , 'standing' , 'speeding up' , 'moving fast' ,
107+ 'slow down' , 'standing' , 'moving fast' , 'speed up' , 'look' , 'nod' , 'unknown' ,
108+ 'moving slow' , 'flasher signal' , 'looking' , 'handwave' , 'clear path' ,
109+ 'stopped' , 'slowing down' , 'crossing' , 'speeding up' ]
110+
111+ formatted_joint_action_set = ['car moving slow' , 'car slowing down' , 'car standing' , 'car speeding up' , 'car moving fast' ,
112+ 'ped slow down' , 'ped standing' , 'ped moving fast' , 'ped speed up' , 'ped look' ,
113+ 'ped nod' , 'ped unknown' , 'ped moving slow' , 'ped flasher signal' , 'ped looking' ,
114+ 'ped handwave' , 'ped clear path' , 'ped stopped' , 'ped slowing down' , 'ped crossing' ,
115+ 'ped speeding up' ]
116+
117+
118+ # -------------------------------------------------
119+ # Network configuration:
120+ print ("Loading network/training configuration..." )
121+ print ("Config file: " + str (__name__ ))
122+
123+ BATCH_SIZE = 25
124+ NB_EPOCHS_CLASS = 100
125+
126+ OPTIM_C = Adam (lr = 0.0000002 , beta_1 = 0.5 )
127+ # OPTIM_C = SGD(lr=0.0001, momentum=0.9, nesterov=True)
128+ # OPTIM_C = RMSprop(lr=0.0001, rho=0.9)
129+
130+ # lr_schedule = [10, 20, 30] # epoch_step
131+
132+ # def schedule(epoch_idx):
133+ # if (epoch_idx + 1) < lr_schedule[0]:
134+ # return 0.00000001
135+ # elif (epoch_idx + 1) < lr_schedule[1]:
136+ # return 0.000000001 # lr_decay_ratio = 10
137+ # elif (epoch_idx + 1) < lr_schedule[2]:
138+ # return 0.000000001
139+ # return 0.000000001
140+
141+
142+ lr_schedule = [7 , 15 , 22 ] # epoch_step
143+
144+
145+ def schedule (epoch_idx ):
146+ if (epoch_idx + 1 ) < lr_schedule [0 ]:
147+ return 0.00001
148+ elif (epoch_idx + 1 ) < lr_schedule [1 ]:
149+ return 0.000001 # lr_decay_ratio = 10
150+ elif (epoch_idx + 1 ) < lr_schedule [2 ]:
151+ return 0.0000001
152+ return 0.0000001
0 commit comments