|
| 1 | +from __future__ import absolute_import |
| 2 | +from __future__ import division |
| 3 | +from __future__ import print_function |
| 4 | + |
| 5 | +from keras.optimizers import SGD |
| 6 | +from keras.optimizers import Adam |
| 7 | +from keras.optimizers import adadelta |
| 8 | +from keras.optimizers import rmsprop |
| 9 | +from keras.layers import Layer |
| 10 | +from keras import backend as K |
| 11 | +K.set_image_dim_ordering('tf') |
| 12 | +import socket |
| 13 | +import os |
| 14 | + |
| 15 | +# ------------------------------------------------- |
| 16 | +# Background config: |
| 17 | +hostname = socket.gethostname() |
| 18 | +if hostname == 'baymax': |
| 19 | + path_var = 'baymax/' |
| 20 | +elif hostname == 'walle': |
| 21 | + path_var = 'walle/' |
| 22 | +elif hostname == 'bender': |
| 23 | + path_var = 'bender/' |
| 24 | +else: |
| 25 | + path_var = 'zhora/' |
| 26 | + |
| 27 | +DATA_DIR= '/local_home/JAAD_Dataset/iros/resized_imgs_208_thesis/train/' |
| 28 | + |
| 29 | +VAL_DATA_DIR= '/local_home/JAAD_Dataset/iros/resized_imgs_208_thesis/val/' |
| 30 | + |
| 31 | +TEST_DATA_DIR= '/local_home/JAAD_Dataset/iros/resized_imgs_208_thesis/test/' |
| 32 | +# TEST_DATA_DIR= '/local_home/JAAD_Dataset/fun_experiments/resized/' |
| 33 | + |
| 34 | +RESULTS_DIR = '/local_home/JAAD_Dataset/thesis/results/NRNN16/' |
| 35 | + |
| 36 | +# MODEL_DIR = './../' + path_var + 'models' |
| 37 | +MODEL_DIR = RESULTS_DIR + 'models/' |
| 38 | +if not os.path.exists(MODEL_DIR): |
| 39 | + os.mkdir(MODEL_DIR) |
| 40 | + |
| 41 | +# CHECKPOINT_DIR = './../' + path_var + 'checkpoints' |
| 42 | +CHECKPOINT_DIR = RESULTS_DIR + 'checkpoints/' |
| 43 | +if not os.path.exists(CHECKPOINT_DIR): |
| 44 | + os.mkdir(CHECKPOINT_DIR) |
| 45 | + |
| 46 | +GEN_IMAGES_DIR = RESULTS_DIR + 'generated_images/' |
| 47 | +if not os.path.exists(GEN_IMAGES_DIR): |
| 48 | + os.mkdir(GEN_IMAGES_DIR) |
| 49 | + |
| 50 | +LOG_DIR = RESULTS_DIR + 'logs/' |
| 51 | +if not os.path.exists(LOG_DIR): |
| 52 | + os.mkdir(LOG_DIR) |
| 53 | + |
| 54 | +TF_LOG_DIR = RESULTS_DIR + 'tf_logs/' |
| 55 | +if not os.path.exists(TF_LOG_DIR): |
| 56 | + os.mkdir(TF_LOG_DIR) |
| 57 | + |
| 58 | +TEST_RESULTS_DIR = RESULTS_DIR + 'test_results/' |
| 59 | +if not os.path.exists(TEST_RESULTS_DIR): |
| 60 | + os.mkdir(TEST_RESULTS_DIR) |
| 61 | + |
| 62 | +PRINT_MODEL_SUMMARY = True |
| 63 | +SAVE_MODEL = True |
| 64 | +PLOT_MODEL = True |
| 65 | +SAVE_GENERATED_IMAGES = True |
| 66 | +SHUFFLE = True |
| 67 | +VIDEO_LENGTH = 32 |
| 68 | +IMG_SIZE = (128, 208, 3) |
| 69 | +RAM_DECIMATE = False |
| 70 | +REVERSE = True |
| 71 | +FILTER_SIZE = 3 |
| 72 | + |
| 73 | + |
| 74 | +# ------------------------------------------------- |
| 75 | +# Network configuration: |
| 76 | +print ("Loading network/training configuration.") |
| 77 | +print ("Config file: " + str(__name__)) |
| 78 | + |
| 79 | +BATCH_SIZE = 9 |
| 80 | +TEST_BATCH_SIZE = 1 |
| 81 | +NB_EPOCHS_AUTOENCODER = 30 |
| 82 | + |
| 83 | +# OPTIM_A = Adam(lr=0.0001, beta_1=0.5) |
| 84 | +OPTIM_A = rmsprop(lr=0.0001, rho=0.9) |
| 85 | +OPTIM_B = rmsprop(lr=0.00001, rho=0.9) |
| 86 | +# OPTIM_A = SGD(lr=0.000001, momentum=0.5, nesterov=True) |
| 87 | + |
| 88 | +lr_schedule = [7, 14, 20, 30] # epoch_step |
| 89 | + |
| 90 | +def schedule(epoch_idx): |
| 91 | + if (epoch_idx) <= lr_schedule[0]: |
| 92 | + return 0.001 |
| 93 | + elif (epoch_idx) <= lr_schedule[1]: |
| 94 | + return 0.0001 # lr_decay_ratio = 10 |
| 95 | + elif (epoch_idx) <= lr_schedule[2]: |
| 96 | + return 0.00001 # lr_decay_ratio = 10 |
| 97 | + elif (epoch_idx) <= lr_schedule[3]: |
| 98 | + return 0.00001 |
| 99 | + return 0.00001 |
| 100 | + |
| 101 | + |
| 102 | + |
0 commit comments