Skip to content

Commit 0974da5

Browse files
committed
Weird val and tests
1 parent 5838603 commit 0974da5

21 files changed

Lines changed: 1924 additions & 393 deletions

code/autoencoder_model/scripts/thesis_scripts/baseline_classifier.py

Lines changed: 229 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828
from sklearn.metrics import classification_report
2929
from sklearn.metrics import precision_recall_fscore_support
3030
from sklearn.metrics import confusion_matrix
31+
from sklearn.metrics import precision_recall_curve
32+
from sklearn.metrics import average_precision_score
33+
from sklearn.metrics import accuracy_score
3134
from image_utils import random_rotation
3235
from image_utils import random_shift
3336
from image_utils import flip_axis
@@ -883,32 +886,250 @@ def test(CLA_WEIGHTS):
883886

884887
# then after each epoch
885888
avg_test_c_loss = np.mean(np.asarray(test_c_loss, dtype=np.float32), axis=0)
886-
887889
test_prec, test_rec, test_fbeta, test_support = get_sklearn_metrics(np.asarray(y_test_true),
888-
np.asarray(y_test_pred),
889-
avg='binary',
890-
pos_label=1)
890+
np.asarray(y_test_pred),
891+
avg='binary',
892+
pos_label=1)
891893
print("\nAvg test_c_loss: " + str(avg_test_c_loss))
892894
print("Test Prec: %.4f, Recall: %.4f, Fbeta: %.4f" % (test_prec, test_rec, test_fbeta))
893895

894-
print ("Classification Report")
896+
test_acc = accuracy_score(y_test_true, np.round(y_test_pred))
897+
print("Test Accuracy: %.4f" % (test_acc))
898+
899+
avg_prec = average_precision_score(y_test_true, y_test_pred)
900+
print("Average precision: %.4f" % (avg_prec))
901+
902+
precisions, recalls, thresholds = precision_recall_curve(y_test_true, y_test_pred)
903+
print("PR curve precisions: " + str(precisions))
904+
print("PR curve recalls: " + str(recalls))
905+
print("PR curve thresholds: " + str(thresholds))
906+
print("PR curve prec mean: %.4f" %(np.mean(precisions)))
907+
print("PR curve prec std: %.4f" %(np.std(precisions)))
908+
print("Number of thresholds: %.4f" %(len(thresholds)))
909+
910+
print("Classification Report")
895911
print(get_classification_report(np.asarray(y_test_true), np.asarray(y_test_pred)))
896912

897-
print ("Confusion matrix")
913+
print("Confusion matrix")
898914
tn, fp, fn, tp = confusion_matrix(y_test_true, np.round(y_test_pred)).ravel()
899-
print ("TN: %.2f, FP: %.2f, FN: %.2f, TP: %.2f" % (tn, fp, fn, tp))
915+
print("TN: %.2f, FP: %.2f, FN: %.2f, TP: %.2f" % (tn, fp, fn, tp))
900916

901917
print("Mean time taken to make " + str(NB_TEST_ITERATIONS) + " predictions: %f"
902918
% (np.mean(np.asarray(iter_endtime) - np.asarray(iter_starttime))))
903919
print("Standard Deviation %f"
904920
% (np.std(np.asarray(iter_endtime) - np.asarray(iter_starttime))))
905921

906-
print("Mean time taken to load and process " + str(NB_TEST_ITERATIONS) + " predictions: %f"
922+
print("Mean time taken to make load and process" + str(NB_TEST_ITERATIONS) + " predictions: %f"
907923
% (np.mean(np.asarray(iter_endtime) - np.asarray(iter_loadtime))))
908924
print("Standard Deviation %f"
909925
% (np.std(np.asarray(iter_endtime) - np.asarray(iter_loadtime))))
910926

911927

def test_mtcp(CLA_WEIGHTS):
    """Measure the classifier's mean time-to-correct-prediction (MTCP).

    Scans the test videos for clips whose pedestrian action label changes
    between the first and last frame, then steps frame-by-frame (up to 16
    frames) until the classifier's rounded prediction agrees with the
    ground-truth label, recording how many frames that took.  Annotated
    prediction / ground-truth image strips are written to
    ``TEST_RESULTS_DIR/mtcp-pred`` and ``TEST_RESULTS_DIR/mtcp-truth``,
    and precision/recall/accuracy/PR-curve summaries are printed.

    CLA_WEIGHTS -- classifier weights spec forwarded to run_utilities()
    for loading; no return value (results are printed and saved to disk).
    """
    # Bug fix: the original only created '/pred/', but the images below are
    # written into '/mtcp-pred/' and '/mtcp-truth/'; cv2.imwrite fails
    # silently when the target directory does not exist.
    for sub in ('/pred/', '/mtcp-pred/', '/mtcp-truth/'):
        if not os.path.exists(TEST_RESULTS_DIR + sub):
            os.mkdir(TEST_RESULTS_DIR + sub)

    # Setup test: frame sources and the sliding-window list of test clips.
    test_frames_source = hkl.load(os.path.join(TEST_DATA_DIR, 'sources_test_208.hkl'))
    test_videos_list = get_video_lists(frames_source=test_frames_source, stride=16, frame_skip=0)

    # Load test action annotations and reduce them to per-frame sigmoid targets.
    test_action_labels = hkl.load(os.path.join(TEST_DATA_DIR, 'annotations_test_208.hkl'))
    test_ped_action_classes, test_ped_class_count = get_action_classes(test_action_labels, mode='sigmoid')
    print("Test Stats: " + str(test_ped_class_count))

    # Build and compile the stacked classifier (ensemble C3D variant).
    print("Creating models.")
    classifier = ensemble_c3d()
    classifier.compile(loss="binary_crossentropy",
                       optimizer=OPTIM_C,
                       metrics=['acc'])

    # Load pretrained weights / print model summaries.
    run_utilities(classifier, CLA_WEIGHTS)

    n_test_videos = test_videos_list.shape[0]
    NB_TEST_ITERATIONS = int(n_test_videos / TEST_BATCH_SIZE)

    # TensorBoard / LR-schedule callbacks, kept for parity with test()
    # even though no training step runs here.
    TC_cla = tb_callback.TensorBoard(log_dir=TF_LOG_CLA_DIR, histogram_freq=0, write_graph=False,
                                     write_images=False)
    LRS_clas = lrs_callback.LearningRateScheduler(schedule=schedule)
    LRS_clas.set_model(classifier)

    if CLASSIFIER:
        print("Testing Classifier...")
        print('')
        tcp_list = []        # frames needed until the prediction matched truth
        tcp_true_list = []   # ground-truth label at the matching/final frame
        tcp_pred_list = []   # raw prediction at the matching/final frame
        y_test_pred = []
        y_test_true = []
        test_c_loss = []
        index = 0
        while index < NB_TEST_ITERATIONS:
            X, y = load_X_y(test_videos_list, index, TEST_DATA_DIR, test_ped_action_classes,
                            batch_size=TEST_BATCH_SIZE)

            y_past_class = y[:, 0]
            y_end_class = y[:, -1]

            # Only clips whose label flips between first and last frame are
            # relevant for time-to-correct-prediction.
            # NOTE(review): the scalar [0] comparisons in this loop assume
            # TEST_BATCH_SIZE == 1 — confirm before raising the batch size.
            if y_end_class[0] == y_past_class[0]:
                index = index + 1
                continue
            else:
                stdout.write("\rIter: " + str(index) + "/" + str(NB_TEST_ITERATIONS - 1))
                stdout.flush()
                for fnum in range(int(VIDEO_LENGTH / 2) + 1):
                    X, y = load_X_y(test_videos_list, index, TEST_DATA_DIR, test_ped_action_classes,
                                    batch_size=TEST_BATCH_SIZE)
                    X_test = X

                    y_true_imgs = X[:, int(VIDEO_LENGTH / 2):]
                    # Walk the ground truth backwards from the clip's end.
                    y_true_class = y[:, VIDEO_LENGTH - fnum - 1]
                    if y[:, 0] == y_true_class[0]:
                        # Walked back to before the label flip; abandon this clip.
                        break

                    if fnum + 1 > 16:
                        # Never converged within 16 frames; record the last
                        # prediction (y_pred_class from the previous pass).
                        tcp_pred_list.append(y_pred_class[0])
                        tcp_true_list.append(y_true_class[0])
                        break

                    # One forward pass, reused below (the original ran
                    # classifier.predict three times on the identical batch,
                    # and the first two results were duplicated/overwritten).
                    y_pred_class = classifier.predict(X_test, verbose=0)
                    y_test_pred.extend(y_pred_class)
                    test_c_loss.append(classifier.test_on_batch(X_test, y_true_class))
                    y_test_true.extend(y_true_class)

                    # Save generated images to file.
                    # NOTE(review): encoder / decoder / sclassifier are not
                    # defined in this function — presumably module-level models
                    # built elsewhere; confirm they exist before test_mtcp runs.
                    z = encoder.predict(X_test)
                    test_predicted_images = decoder.predict(z)
                    test_ped_pred_class = sclassifier.predict(X_test, verbose=0)
                    pred_seq = arrange_images(np.concatenate((X_test, test_predicted_images), axis=1))
                    pred_seq = pred_seq * 127.5 + 127.5

                    truth_image = arrange_images(y_true_imgs)
                    truth_image = truth_image * 127.5 + 127.5

                    font = cv2.FONT_HERSHEY_SIMPLEX
                    y_orig_classes = y[:, 0: int(VIDEO_LENGTH / 2)]
                    y_true_classes = y[:, int(VIDEO_LENGTH / 2):]

                    # Overlay crossing / not-crossing labels on each frame tile.
                    for k in range(TEST_BATCH_SIZE):
                        for j in range(int(VIDEO_LENGTH / 2)):
                            if y_orig_classes[k, j] > 0.5:
                                label_orig = "crossing"
                            else:
                                label_orig = "not crossing"

                            if y_true_classes[k][j] > 0.5:
                                label_true = "crossing"
                            else:
                                label_true = "not crossing"

                            if test_ped_pred_class[k][0] > 0.5:
                                label_pred = "crossing"
                            else:
                                label_pred = "not crossing"

                            cv2.putText(pred_seq, label_orig,
                                        (2 + j * (208), 114 + k * 128), font, 0.5, (255, 255, 255), 1,
                                        cv2.LINE_AA)
                            cv2.putText(pred_seq, label_pred,
                                        (2 + (j + 16) * (208), 114 + k * 128), font, 0.5, (255, 255, 255), 1,
                                        cv2.LINE_AA)
                            cv2.putText(pred_seq, 'truth: ' + label_true,
                                        (2 + (j + 16) * (208), 94 + k * 128), font, 0.5, (255, 255, 255), 1,
                                        cv2.LINE_AA)
                            cv2.putText(truth_image, label_true,
                                        (2 + j * (208), 114 + k * 128), font, 0.5, (255, 255, 255), 1,
                                        cv2.LINE_AA)

                    # Path typo fixed: '/mtcp-pred//' -> '/mtcp-pred/'.
                    cv2.imwrite(os.path.join(TEST_RESULTS_DIR + '/mtcp-pred/', str(index) + "_cla_test_pred.png"),
                                pred_seq)
                    cv2.imwrite(os.path.join(TEST_RESULTS_DIR + '/mtcp-truth/', str(index) + "_cla_test_truth.png"),
                                truth_image)

                    if y_true_class[0] != np.round(y_pred_class[0]):
                        # Still wrong: advance one frame and keep trying.
                        index = index + 1
                        continue
                    else:
                        # Prediction caught up with the truth: record the delay
                        # and jump to the next non-overlapping clip.
                        tcp_pred_list.append(y_pred_class[0])
                        tcp_true_list.append(y_true_class[0])
                        tcp_list.append(fnum + 1)
                        index = index + int(VIDEO_LENGTH / 2)
                        break

        # Aggregate metrics over every frame-level prediction made above.
        avg_test_c_loss = np.mean(np.asarray(test_c_loss, dtype=np.float32), axis=0)

        test_prec, test_rec, test_fbeta, test_support = get_sklearn_metrics(np.asarray(y_test_true),
                                                                            np.asarray(y_test_pred),
                                                                            avg='binary',
                                                                            pos_label=1)
        print("\nAvg test_c_loss: " + str(avg_test_c_loss))
        print("Mean time to change prediction: " + str(np.mean(np.asarray(tcp_list))))
        print("Standard Deviation " + str(np.std(np.asarray(tcp_list))))
        print("Number of correct predictions " + str(len(tcp_list)))
        print("Test Prec: %.4f, Recall: %.4f, Fbeta: %.4f" % (test_prec, test_rec, test_fbeta))

        print("Classification Report")
        print(get_classification_report(np.asarray(y_test_true), np.asarray(y_test_pred)))

        print("Confusion matrix")
        tn, fp, fn, tp = confusion_matrix(y_test_true, np.round(y_test_pred)).ravel()
        print("TN: %.2f, FP: %.2f, FN: %.2f, TP: %.2f" % (tn, fp, fn, tp))

        print("-------------------------------------------")
        print("Test cases where there is a change in label")

        # Same metric battery restricted to the label-change clips only.
        test_prec, test_rec, test_fbeta, test_support = get_sklearn_metrics(np.asarray(tcp_true_list),
                                                                            np.asarray(tcp_pred_list),
                                                                            avg='binary',
                                                                            pos_label=1)
        print("Test Prec: %.4f, Recall: %.4f, Fbeta: %.4f" % (test_prec, test_rec, test_fbeta))

        test_acc = accuracy_score(tcp_true_list, np.round(tcp_pred_list))
        print("Test Accuracy: %.4f" % (test_acc))

        avg_prec = average_precision_score(tcp_true_list, tcp_pred_list)
        print("Average precision: %.4f" % (avg_prec))

        precisions, recalls, thresholds = precision_recall_curve(tcp_true_list, tcp_pred_list)
        print("PR curve precisions: " + str(precisions))
        print("PR curve recalls: " + str(recalls))
        print("PR curve thresholds: " + str(thresholds))
        print("PR curve prec mean: %.4f" % (np.mean(precisions)))
        print("PR curve prec std: %.4f" % (np.std(precisions)))
        print("Number of thresholds: %.4f" % (len(thresholds)))

        print("Classification Report")
        print(get_classification_report(np.asarray(tcp_true_list), np.asarray(tcp_pred_list)))

        print("Confusion matrix")
        tn, fp, fn, tp = confusion_matrix(tcp_true_list, np.round(tcp_pred_list)).ravel()
        print("TN: %.2f, FP: %.2f, FN: %.2f, TP: %.2f" % (tn, fp, fn, tp))
9121133
def get_args():
9131134
parser = argparse.ArgumentParser()
9141135
parser.add_argument("--mode", type=str)

0 commit comments

Comments
 (0)