@@ -75,6 +75,7 @@ def __init__(self, data_path, batch_size, training=True):
7575
7676 self .train_inds = np .concatenate ((self .pos_train_inds , self .neg_train_inds ))
7777 self .batch_size = batch_size
78+ self .p_pos = np .ones (self .pos_train_inds .shape )/ len (self .pos_train_inds )
7879
7980 def get_train_size (self ):
8081 return self .pos_train_inds .shape [0 ] + self .neg_train_inds .shape [0 ]
@@ -84,7 +85,7 @@ def __len__(self):
8485
8586 def __getitem__ (self , index ):
8687 selected_pos_inds = np .random .choice (
87- self .pos_train_inds , size = self .batch_size // 2 , replace = False
88+ self .pos_train_inds , size = self .batch_size // 2 , replace = False , p = self . p_pos
8889 )
8990 selected_neg_inds = np .random .choice (
9091 self .neg_train_inds , size = self .batch_size // 2 , replace = False
@@ -94,8 +95,7 @@ def __getitem__(self, index):
9495 sorted_inds = np .sort (selected_inds )
9596 train_img = (self .images [sorted_inds ] / 255.0 ).astype (np .float32 )
9697 train_label = self .labels [sorted_inds , ...]
97- inds = np .random .permutation (np .arange (len (train_img )))
98- return np .array (train_img [inds ]), np .array (train_label [inds ])
98+ return np .array (train_img ), np .array (train_label )
9999
100100 def get_n_most_prob_faces (self , prob , n ):
101101 idx = np .argsort (prob )[::- 1 ]
@@ -121,7 +121,7 @@ def get_test_faces():
121121 return images ["LF" ], images ["LM" ], images ["DF" ], images ["DM" ]
122122
123123
124- def plot_k (imgs ):
124+ def plot_k (imgs , fname = None ):
125125 fig = plt .figure ()
126126 fig .subplots_adjust (hspace = 0.6 )
127127 num_images = len (imgs )
@@ -133,10 +133,12 @@ def plot_k(imgs):
133133 ax .imshow (img_to_show , interpolation = "nearest" )
134134 plt .subplots_adjust (wspace = 0.20 , hspace = 0.20 )
135135 plt .show ()
136+ if fname :
137+ plt .savefig (fname )
136138 plt .clf ()
137139
138140
139- def plot_percentile (imgs ):
141+ def plot_percentile (imgs , fname = None ):
140142 fig = plt .figure ()
141143 fig , axs = plt .subplots (1 , len (imgs ), figsize = (11 , 8 ))
142144 for img in range (len (imgs )):
@@ -145,3 +147,26 @@ def plot_percentile(imgs):
145147 ax .yaxis .set_visible (False )
146148 img_to_show = imgs [img ]
147149 ax .imshow (img_to_show , interpolation = "nearest" )
150+ if fname :
151+ plt .savefig (fname )
152+
153+ def plot_accuracy_vs_risk (sorted_images , sorted_uncertainty , sorted_preds , plot_title ):
154+ num_percentile_intervals = 10
155+ num_samples = len (sorted_images ) // num_percentile_intervals
156+ all_imgs = []
157+ all_unc = []
158+ all_acc = []
159+ for percentile in range (num_percentile_intervals ):
160+ cur_imgs = sorted_images [percentile * num_samples : (percentile + 1 ) * num_samples ]
161+ cur_unc = sorted_uncertainty [percentile * num_samples : (percentile + 1 ) * num_samples ]
162+ cur_predictions = tf .nn .sigmoid (sorted_preds [percentile * num_samples : (percentile + 1 ) * num_samples ])
163+ avged_imgs = tf .reduce_mean (cur_imgs , axis = 0 )
164+ all_imgs .append (avged_imgs )
165+ all_unc .append (tf .reduce_mean (cur_unc ))
166+ all_acc .append ((np .ones ((num_samples )) == np .rint (cur_predictions )).mean ())
167+
168+ plt .plot (np .arange (num_percentile_intervals ) * 10 , all_acc )
169+ plt .title (plot_title )
170+ plt .show ()
171+ plt .clf ()
172+ return all_imgs
0 commit comments