|
28 | 28 | from sklearn.metrics import classification_report |
29 | 29 | from sklearn.metrics import precision_recall_fscore_support |
30 | 30 | from sklearn.metrics import confusion_matrix |
| 31 | +from sklearn.metrics import precision_recall_curve |
| 32 | +from sklearn.metrics import average_precision_score |
| 33 | +from sklearn.metrics import accuracy_score |
31 | 34 | from image_utils import random_rotation |
32 | 35 | from image_utils import random_shift |
33 | 36 | from image_utils import flip_axis |
@@ -883,32 +886,250 @@ def test(CLA_WEIGHTS): |
883 | 886 |
|
884 | 887 | # then after each epoch |
885 | 888 | avg_test_c_loss = np.mean(np.asarray(test_c_loss, dtype=np.float32), axis=0) |
886 | | - |
887 | 889 | test_prec, test_rec, test_fbeta, test_support = get_sklearn_metrics(np.asarray(y_test_true), |
888 | | - np.asarray(y_test_pred), |
889 | | - avg='binary', |
890 | | - pos_label=1) |
| 890 | + np.asarray(y_test_pred), |
| 891 | + avg='binary', |
| 892 | + pos_label=1) |
891 | 893 | print("\nAvg test_c_loss: " + str(avg_test_c_loss)) |
892 | 894 | print("Test Prec: %.4f, Recall: %.4f, Fbeta: %.4f" % (test_prec, test_rec, test_fbeta)) |
893 | 895 |
|
894 | | - print ("Classification Report") |
| 896 | + test_acc = accuracy_score(y_test_true, np.round(y_test_pred)) |
| 897 | + print("Test Accuracy: %.4f" % (test_acc)) |
| 898 | + |
| 899 | + avg_prec = average_precision_score(y_test_true, y_test_pred) |
| 900 | + print("Average precision: %.4f" % (avg_prec)) |
| 901 | + |
| 902 | + precisions, recalls, thresholds = precision_recall_curve(y_test_true, y_test_pred) |
| 903 | + print("PR curve precisions: " + str(precisions)) |
| 904 | + print("PR curve recalls: " + str(recalls)) |
| 905 | + print("PR curve thresholds: " + str(thresholds)) |
| 906 | + print("PR curve prec mean: %.4f" %(np.mean(precisions))) |
| 907 | + print("PR curve prec std: %.4f" %(np.std(precisions))) |
| 908 | + print("Number of thresholds: %.4f" %(len(thresholds))) |
| 909 | + |
| 910 | + print("Classification Report") |
895 | 911 | print(get_classification_report(np.asarray(y_test_true), np.asarray(y_test_pred))) |
896 | 912 |
|
897 | | - print ("Confusion matrix") |
| 913 | + print("Confusion matrix") |
898 | 914 | tn, fp, fn, tp = confusion_matrix(y_test_true, np.round(y_test_pred)).ravel() |
899 | | - print ("TN: %.2f, FP: %.2f, FN: %.2f, TP: %.2f" % (tn, fp, fn, tp)) |
| 915 | + print("TN: %.2f, FP: %.2f, FN: %.2f, TP: %.2f" % (tn, fp, fn, tp)) |
900 | 916 |
|
901 | 917 | print("Mean time taken to make " + str(NB_TEST_ITERATIONS) + " predictions: %f" |
902 | 918 | % (np.mean(np.asarray(iter_endtime) - np.asarray(iter_starttime)))) |
903 | 919 | print("Standard Deviation %f" |
904 | 920 | % (np.std(np.asarray(iter_endtime) - np.asarray(iter_starttime)))) |
905 | 921 |
|
906 | | - print("Mean time taken to load and process " + str(NB_TEST_ITERATIONS) + " predictions: %f" |
| 922 | + print("Mean time taken to make load and process" + str(NB_TEST_ITERATIONS) + " predictions: %f" |
907 | 923 | % (np.mean(np.asarray(iter_endtime) - np.asarray(iter_loadtime)))) |
908 | 924 | print("Standard Deviation %f" |
909 | 925 | % (np.std(np.asarray(iter_endtime) - np.asarray(iter_loadtime)))) |
910 | 926 |
|
911 | 927 |
|
| 928 | +def test_mtcp(CLA_WEIGHTS): |
| 929 | + |
| 930 | + if not os.path.exists(TEST_RESULTS_DIR + '/pred/'): |
| 931 | + os.mkdir(TEST_RESULTS_DIR + '/pred/') |
| 932 | + |
| 933 | + # Setup test |
| 934 | + test_frames_source = hkl.load(os.path.join(TEST_DATA_DIR, 'sources_test_208.hkl')) |
| 935 | + # test_videos_list = get_video_lists(frames_source=test_frames_source, stride=8, frame_skip=0) |
| 936 | + test_videos_list = get_video_lists(frames_source=test_frames_source, stride=16, frame_skip=0) |
| 937 | + # test_videos_list = get_video_lists(frames_source=test_frames_source, stride=16, frame_skip=2) |
| 938 | + # Load test action annotations |
| 939 | + test_action_labels = hkl.load(os.path.join(TEST_DATA_DIR, 'annotations_test_208.hkl')) |
| 940 | + test_ped_action_classes, test_ped_class_count = get_action_classes(test_action_labels, mode='sigmoid') |
| 941 | + print("Test Stats: " + str(test_ped_class_count)) |
| 942 | + |
| 943 | + # Build the Spatio-temporal Autoencoder |
| 944 | + print("Creating models.") |
| 945 | + # Build stacked classifier |
| 946 | + # classifier = pretrained_c3d() |
| 947 | + classifier = ensemble_c3d() |
| 948 | + # classifier = c3d_scratch() |
| 949 | + classifier.compile(loss="binary_crossentropy", |
| 950 | + optimizer=OPTIM_C, |
| 951 | + # metrics=[metric_precision, metric_recall, metric_mpca, 'accuracy']) |
| 952 | + metrics=['acc']) |
| 953 | + |
| 954 | + run_utilities(classifier, CLA_WEIGHTS) |
| 955 | + |
| 956 | + n_test_videos = test_videos_list.shape[0] |
| 957 | + |
| 958 | + NB_TEST_ITERATIONS = int(n_test_videos / TEST_BATCH_SIZE) |
| 959 | + # NB_TEST_ITERATIONS = 5 |
| 960 | + |
| 961 | + # Setup TensorBoard Callback |
| 962 | + TC_cla = tb_callback.TensorBoard(log_dir=TF_LOG_CLA_DIR, histogram_freq=0, write_graph=False, |
| 963 | + write_images=False) |
| 964 | + LRS_clas = lrs_callback.LearningRateScheduler(schedule=schedule) |
| 965 | + LRS_clas.set_model(classifier) |
| 966 | + if CLASSIFIER: |
| 967 | + print("Testing Classifier...") |
| 968 | + # Run over test data |
| 969 | + print('') |
| 970 | + # Time to correct prediction |
| 971 | + tcp_list = [] |
| 972 | + tcp_true_list = [] |
| 973 | + tcp_pred_list = [] |
| 974 | + y_test_pred = [] |
| 975 | + y_test_true = [] |
| 976 | + test_c_loss = [] |
| 977 | + index = 0 |
| 978 | + tcp = 1 |
| 979 | + while index < NB_TEST_ITERATIONS: |
| 980 | + X, y = load_X_y(test_videos_list, index, TEST_DATA_DIR, test_ped_action_classes, |
| 981 | + batch_size=TEST_BATCH_SIZE) |
| 982 | + |
| 983 | + y_past_class = y[:, 0] |
| 984 | + y_end_class = y[:,-1] |
| 985 | + |
| 986 | + if y_end_class[0] == y_past_class[0]: |
| 987 | + index = index + 1 |
| 988 | + continue |
| 989 | + else: |
| 990 | + stdout.write("\rIter: " + str(index) + "/" + str(NB_TEST_ITERATIONS - 1)) |
| 991 | + stdout.flush() |
| 992 | + for fnum in range (int(VIDEO_LENGTH/2) + 1): |
| 993 | + |
| 994 | + X, y = load_X_y(test_videos_list, index, TEST_DATA_DIR, test_ped_action_classes, |
| 995 | + batch_size=TEST_BATCH_SIZE) |
| 996 | + X_test = X |
| 997 | + |
| 998 | + y_true_imgs = X[:, int(VIDEO_LENGTH / 2):] |
| 999 | + y_true_class = y[:, VIDEO_LENGTH - fnum - 1] |
| 1000 | + if y[:, 0] == y_true_class[0]: |
| 1001 | + break |
| 1002 | + |
| 1003 | + if (fnum + 1 > 16): |
| 1004 | + tcp_pred_list.append(y_pred_class[0]) |
| 1005 | + tcp_true_list.append(y_true_class[0]) |
| 1006 | + break |
| 1007 | + |
| 1008 | + y_pred_class = classifier.predict(X_test, verbose=0) |
| 1009 | + y_test_pred.extend(classifier.predict(X_test, verbose=0)) |
| 1010 | + test_c_loss.append(classifier.test_on_batch(X_test, y_true_class)) |
| 1011 | + y_test_true.extend(y_true_class) |
| 1012 | + |
| 1013 | + test_ped_pred_class = classifier.predict(X_test, verbose=0) |
| 1014 | + # pred_seq = arrange_images(np.concatenate((X_train, predicted_images), axis=1)) |
| 1015 | + pred_seq = arrange_images(X_test) |
| 1016 | + pred_seq = pred_seq * 127.5 + 127.5 |
| 1017 | + |
| 1018 | + # Save generated images to file |
| 1019 | + z = encoder.predict(X_test) |
| 1020 | + test_predicted_images = decoder.predict(z) |
| 1021 | + test_ped_pred_class = sclassifier.predict(X_test, verbose=0) |
| 1022 | + pred_seq = arrange_images(np.concatenate((X_test, test_predicted_images), axis=1)) |
| 1023 | + pred_seq = pred_seq * 127.5 + 127.5 |
| 1024 | + |
| 1025 | + truth_image = arrange_images(y_true_imgs) |
| 1026 | + truth_image = truth_image * 127.5 + 127.5 |
| 1027 | + |
| 1028 | + font = cv2.FONT_HERSHEY_SIMPLEX |
| 1029 | + y_orig_classes = y[:, 0: int(VIDEO_LENGTH / 2)] |
| 1030 | + y_true_classes = y[:, int(VIDEO_LENGTH / 2):] |
| 1031 | + |
| 1032 | + # Add labels as text to the image |
| 1033 | + for k in range(TEST_BATCH_SIZE): |
| 1034 | + for j in range(int(VIDEO_LENGTH / 2)): |
| 1035 | + if y_orig_classes[k, j] > 0.5: |
| 1036 | + label_orig = "crossing" |
| 1037 | + else: |
| 1038 | + label_orig = "not crossing" |
| 1039 | + |
| 1040 | + if y_true_classes[k][j] > 0.5: |
| 1041 | + label_true = "crossing" |
| 1042 | + else: |
| 1043 | + label_true = "not crossing" |
| 1044 | + |
| 1045 | + if test_ped_pred_class[k][0] > 0.5: |
| 1046 | + label_pred = "crossing" |
| 1047 | + else: |
| 1048 | + label_pred = "not crossing" |
| 1049 | + |
| 1050 | + cv2.putText(pred_seq, label_orig, |
| 1051 | + (2 + j * (208), 114 + k * 128), font, 0.5, (255, 255, 255), 1, |
| 1052 | + cv2.LINE_AA) |
| 1053 | + cv2.putText(pred_seq, label_pred, |
| 1054 | + (2 + (j + 16) * (208), 114 + k * 128), font, 0.5, (255, 255, 255), 1, |
| 1055 | + cv2.LINE_AA) |
| 1056 | + cv2.putText(pred_seq, 'truth: ' + label_true, |
| 1057 | + (2 + (j + 16) * (208), 94 + k * 128), font, 0.5, (255, 255, 255), 1, |
| 1058 | + cv2.LINE_AA) |
| 1059 | + cv2.putText(truth_image, label_true, |
| 1060 | + (2 + j * (208), 114 + k * 128), font, 0.5, (255, 255, 255), 1, |
| 1061 | + cv2.LINE_AA) |
| 1062 | + |
| 1063 | + cv2.imwrite(os.path.join(TEST_RESULTS_DIR + '/mtcp-pred//', str(index) + "_cla_test_pred.png"), |
| 1064 | + pred_seq) |
| 1065 | + cv2.imwrite(os.path.join(TEST_RESULTS_DIR + '/mtcp-truth/', str(index) + "_cla_test_truth.png"), |
| 1066 | + truth_image) |
| 1067 | + |
| 1068 | + if y_true_class[0] != np.round(y_pred_class[0]): |
| 1069 | + index = index + 1 |
| 1070 | + continue |
| 1071 | + else: |
| 1072 | + tcp_pred_list.append(y_pred_class[0]) |
| 1073 | + tcp_true_list.append(y_true_class[0]) |
| 1074 | + tcp_list.append(fnum + 1) |
| 1075 | + index = index + int(VIDEO_LENGTH / 2) |
| 1076 | + # Break from the for loop |
| 1077 | + break |
| 1078 | + |
| 1079 | + |
| 1080 | + # then after each epoch |
| 1081 | + avg_test_c_loss = np.mean(np.asarray(test_c_loss, dtype=np.float32), axis=0) |
| 1082 | + |
| 1083 | + test_prec, test_rec, test_fbeta, test_support = get_sklearn_metrics(np.asarray(y_test_true), |
| 1084 | + np.asarray(y_test_pred), |
| 1085 | + avg='binary', |
| 1086 | + pos_label=1) |
| 1087 | + print("\nAvg test_c_loss: " + str(avg_test_c_loss)) |
| 1088 | + print("Mean time to change prediction: " + str(np.mean(np.asarray(tcp_list)))) |
| 1089 | + print("Standard Deviation " + str(np.std(np.asarray(tcp_list)))) |
| 1090 | + print ("Number of correct predictions " + str(len(tcp_list))) |
| 1091 | + print("Test Prec: %.4f, Recall: %.4f, Fbeta: %.4f" % (test_prec, test_rec, test_fbeta)) |
| 1092 | + |
| 1093 | + print("Classification Report") |
| 1094 | + print(get_classification_report(np.asarray(y_test_true), np.asarray(y_test_pred))) |
| 1095 | + |
| 1096 | + print("Confusion matrix") |
| 1097 | + tn, fp, fn, tp = confusion_matrix(y_test_true, np.round(y_test_pred)).ravel() |
| 1098 | + print("TN: %.2f, FP: %.2f, FN: %.2f, TP: %.2f" % (tn, fp, fn, tp)) |
| 1099 | + |
| 1100 | + print ("-------------------------------------------") |
| 1101 | + print ("Test cases where there is a change in label") |
| 1102 | + |
| 1103 | + test_prec, test_rec, test_fbeta, test_support = get_sklearn_metrics(np.asarray(tcp_true_list), |
| 1104 | + np.asarray(tcp_pred_list), |
| 1105 | + avg='binary', |
| 1106 | + pos_label=1) |
| 1107 | + print("Test Prec: %.4f, Recall: %.4f, Fbeta: %.4f" % (test_prec, test_rec, test_fbeta)) |
| 1108 | + |
| 1109 | + test_acc = accuracy_score(tcp_true_list, np.round(tcp_pred_list)) |
| 1110 | + print("Test Accuracy: %.4f" % (test_acc)) |
| 1111 | + |
| 1112 | + avg_prec = average_precision_score(tcp_true_list, tcp_pred_list) |
| 1113 | + print("Average precision: %.4f" % (avg_prec)) |
| 1114 | + |
| 1115 | + precisions, recalls, thresholds = precision_recall_curve(tcp_true_list, tcp_pred_list) |
| 1116 | + print("PR curve precisions: " + str(precisions)) |
| 1117 | + print("PR curve recalls: " + str(recalls)) |
| 1118 | + print("PR curve thresholds: " + str(thresholds)) |
| 1119 | + print("PR curve prec mean: %.4f" % (np.mean(precisions))) |
| 1120 | + print("PR curve prec std: %.4f" % (np.std(precisions))) |
| 1121 | + print("Number of thresholds: %.4f" % (len(thresholds))) |
| 1122 | + |
| 1123 | + print("Classification Report") |
| 1124 | + print(get_classification_report(np.asarray(tcp_true_list), np.asarray(tcp_pred_list))) |
| 1125 | + |
| 1126 | + print("Confusion matrix") |
| 1127 | + tn, fp, fn, tp = confusion_matrix(tcp_true_list, np.round(tcp_pred_list)).ravel() |
| 1128 | + print("TN: %.2f, FP: %.2f, FN: %.2f, TP: %.2f" % (tn, fp, fn, tp)) |
| 1129 | + |
| 1130 | + |
| 1131 | + |
| 1132 | + |
912 | 1133 | def get_args(): |
913 | 1134 | parser = argparse.ArgumentParser() |
914 | 1135 | parser.add_argument("--mode", type=str) |
|
0 commit comments