1+ from __future__ import absolute_import
2+ from __future__ import division
3+ from __future__ import print_function
4+
5+ from keras .optimizers import SGD
6+ from keras .optimizers import Adam
7+ from keras .optimizers import RMSprop
8+ from keras import backend as K
9+ K .set_image_dim_ordering ('tf' )
10+ import socket
11+ import os
12+
13+ # -------------------------------------------------
14+ # Background config:
15+ hostname = socket .gethostname ()
16+ if hostname == 'baymax' :
17+ path_var = 'baymax/'
18+ elif hostname == 'walle' :
19+ path_var = 'walle/'
20+ elif hostname == 'bender' :
21+ path_var = 'bender/'
22+ else :
23+ path_var = 'zhora/'
24+
# Resized (208-wide) JAAD frames, pre-sorted into train/test/val splits.
DATA_DIR = '/local_home/JAAD_Dataset/iros/resized_imgs_208_sorted/train/'
# DATA_DIR= '/local_home/data/KITTI_data/'

TEST_DATA_DIR = '/local_home/JAAD_Dataset/iros/resized_imgs_208_sorted/test/'

VAL_DATA_DIR = '/local_home/JAAD_Dataset/iros/resized_imgs_208_sorted/val/'

# Sports1M-pretrained C3D: architecture (JSON) and weights (HDF5) used to
# warm-start the classifier. Paths are machine-specific absolute paths.
PRETRAINED_C3D = '/home/pratik/git_projects/c3d-keras/models/sports1M_weights_tf.json'
PRETRAINED_C3D_WEIGHTS = '/home/pratik/git_projects/c3d-keras/models/sports1M_weights_tf.h5'
34+
# Per-host output directories, created on import if missing.
MODEL_DIR = './../' + path_var + 'models'
CHECKPOINT_DIR = './../' + path_var + 'checkpoints'
GEN_IMAGES_DIR = './../' + path_var + 'generated_images'
CLA_GEN_IMAGES_DIR = GEN_IMAGES_DIR + '/cla_gen/'
# ATTN_WEIGHTS_DIR = './../' + path_var + 'attn_weights'
LOG_DIR = './../' + path_var + 'logs'
TF_LOG_DIR = './../' + path_var + 'tf_logs'
TF_LOG_CLA_DIR = './../' + path_var + 'tf_cla_logs'
TEST_RESULTS_DIR = './../' + path_var + 'test_results'

# DRY: one loop instead of nine copies of the check-then-mkdir stanza.
# os.mkdir (not makedirs) is kept deliberately so a missing parent
# ('./../<host>/') still fails loudly, as in the original. Order matters:
# CLA_GEN_IMAGES_DIR lives inside GEN_IMAGES_DIR, so its parent comes first.
for _dir in (MODEL_DIR, CHECKPOINT_DIR, GEN_IMAGES_DIR, CLA_GEN_IMAGES_DIR,
             LOG_DIR, TF_LOG_DIR, TF_LOG_CLA_DIR, TEST_RESULTS_DIR):
    if not os.path.exists(_dir):
        os.mkdir(_dir)
70+
# ---- Training / feature flags ----
PRINT_MODEL_SUMMARY = True   # print the Keras model summary after building
SAVE_MODEL = True            # persist the model to MODEL_DIR
PLOT_MODEL = False           # render the model graph to an image file
SAVE_GENERATED_IMAGES = True # dump generated frames to GEN_IMAGES_DIR
SHUFFLE = True               # shuffle training data
VIDEO_LENGTH = 16            # frames per input clip
IMG_SIZE = (128, 208, 3)     # (height, width, channels) — channels-last
VIS_ATTN = True              # visualise attention maps
CLASSIFIER = True            # train/run the action classifier head
BUF_SIZE = 10                # presumably an in-RAM clip buffer size — confirm against loader
LOSS_WEIGHTS = [1, 1]        # relative weights for the two loss terms
A_TRAIN_RATIO = 1            # NOTE(review): train-step ratios for the two sub-models — confirm
C_TRAIN_RATIO = 1
RAM_DECIMATE = True          # presumably subsample data kept in RAM — verify in loader
RETRAIN_CLASSIFIER = True    # fine-tune the classifier rather than freezing it
CLASS_TARGET_INDEX = 8       # presumably the clip frame index used as the class target — confirm
# ---- Data-augmentation ranges ----
ROT_MAX = 5                  # max rotation (degrees, presumably)
SFT_H_MAX = 0.02             # max horizontal shift (fraction of width)
SFT_V_MAX = 0.02             # max vertical shift (fraction of height)
ZOOM_MAX = 0.2               # max zoom factor delta
BRIGHT_RANGE_L = 0.5         # brightness scale, lower bound
BRIGHT_RANGE_H = 1.5         # brightness scale, upper bound

# Full JAAD pedestrian-action label set.
ped_actions = ['slow down', 'standing', 'walking', 'speed up', 'nod', 'unknown',
               'clear path', 'handwave', 'crossing', 'looking', 'no ped']

# Reduced label set used when classifying only coarse behaviour.
simple_ped_set = ['standing', 'crossing', 'no ped']
98+
99+ # -------------------------------------------------
100+ # Network configuration:
101+ print ("Loading network/training configuration..." )
102+ print ("Config file: " + str (__name__ ))
103+
104+ BATCH_SIZE = 25
105+ NB_EPOCHS_CLASS = 100
106+
107+ OPTIM_C = Adam (lr = 0.0000002 , beta_1 = 0.5 )
108+ # OPTIM_C = SGD(lr=0.0001, momentum=0.9, nesterov=True)
109+ # OPTIM_C = RMSprop(lr=0.0001, rho=0.9)
110+
111+ # lr_schedule = [10, 20, 30] # epoch_step
112+
113+ # def schedule(epoch_idx):
114+ # if (epoch_idx + 1) < lr_schedule[0]:
115+ # return 0.00000001
116+ # elif (epoch_idx + 1) < lr_schedule[1]:
117+ # return 0.000000001 # lr_decay_ratio = 10
118+ # elif (epoch_idx + 1) < lr_schedule[2]:
119+ # return 0.000000001
120+ # return 0.000000001
121+
122+
# Epoch boundaries for the piecewise-constant learning-rate schedule.
lr_schedule = [7, 15, 22]  # epoch_step


def schedule(epoch_idx):
    """Return the learning rate for the given 0-based epoch index.

    Epochs 1..6 -> 1e-5, 7..14 -> 1e-6, 15..21 -> 1e-7, and 22 onward -> 1e-6
    (NOTE(review): the final rate rising back to 1e-6 looks odd — confirm it
    is intentional).
    """
    epoch = epoch_idx + 1
    tiers = ((lr_schedule[0], 0.00001),
             (lr_schedule[1], 0.000001),   # lr_decay_ratio = 10
             (lr_schedule[2], 0.0000001))
    for boundary, rate in tiers:
        if epoch < boundary:
            return rate
    return 0.000001