|
40 | 40 | ] |
41 | 41 | }, |
42 | 42 | { |
| 43 | + "attachments": {}, |
43 | 44 | "cell_type": "markdown", |
44 | 45 | "metadata": { |
45 | 46 | "id": "IgYKebt871EK" |
|
59 | 60 | ] |
60 | 61 | }, |
61 | 62 | { |
| 63 | + "attachments": {}, |
62 | 64 | "cell_type": "markdown", |
63 | 65 | "metadata": { |
64 | 66 | "id": "6JTRoM7E71EU" |
|
95 | 97 | ] |
96 | 98 | }, |
97 | 99 | { |
| 100 | + "attachments": {}, |
98 | 101 | "cell_type": "markdown", |
99 | 102 | "metadata": { |
100 | 103 | "id": "6VKVqLb371EV" |
|
130 | 133 | ] |
131 | 134 | }, |
132 | 135 | { |
| 136 | + "attachments": {}, |
133 | 137 | "cell_type": "markdown", |
134 | 138 | "metadata": { |
135 | 139 | "id": "cREmhMWJ71EX" |
|
143 | 147 | ] |
144 | 148 | }, |
145 | 149 | { |
| 150 | + "attachments": {}, |
146 | 151 | "cell_type": "markdown", |
147 | 152 | "metadata": { |
148 | 153 | "id": "1NhotGiT71EY" |
|
199 | 204 | ] |
200 | 205 | }, |
201 | 206 | { |
| 207 | + "attachments": {}, |
202 | 208 | "cell_type": "markdown", |
203 | 209 | "metadata": { |
204 | 210 | "id": "LgTG6buf71Ea" |
|
256 | 262 | ] |
257 | 263 | }, |
258 | 264 | { |
| 265 | + "attachments": {}, |
259 | 266 | "cell_type": "markdown", |
260 | 267 | "metadata": { |
261 | 268 | "id": "SzFGcrhv71Ed" |
|
304 | 311 | "test_imgs = test_loader.get_all_faces()\n", |
305 | 312 | "\n", |
306 | 313 | "# Call the Capsa-wrapped classifier to generate outputs: predictions, uncertainty, and bias!\n", |
307 | | - "predictions, uncertainty, bias = wrapped_model.predict(test_imgs, batch_size=512)" |
| 314 | + "#predictions, uncertainty, bias = wrapped_model.predict(test_imgs, batch_size=512)\n", |
| 315 | + "out = wrapped_model.predict(test_imgs, batch_size=512)\n" |
308 | 316 | ] |
309 | 317 | }, |
310 | 318 | { |
| 319 | + "attachments": {}, |
311 | 320 | "cell_type": "markdown", |
312 | 321 | "metadata": { |
313 | 322 | "id": "629ng-_H6WOk" |
|
329 | 338 | "### Analyzing representation bias scores ###\n", |
330 | 339 | "\n", |
331 | 340 | "# Sort according to lowest to highest representation scores\n", |
332 | | - "indices = np.argsort(bias, axis=None) # sort the score values themselves\n", |
| 341 | + "indices = np.argsort(out.bias, axis=None) # sort the score values themselves\n", |
333 | 342 | "sorted_images = test_imgs[indices] # sort images from lowest to highest representations\n", |
334 | | - "sorted_biases = bias[indices] # order the representation bias scores\n", |
335 | | - "sorted_preds = predictions[indices] # order the prediction values\n", |
| 343 | + "sorted_biases = out.bias.numpy()[indices] # order the representation bias scores\n", |
| 344 | + "sorted_preds = out.y_hat.numpy()[indices] # order the prediction values\n", |
336 | 345 | "\n", |
337 | 346 | "\n", |
338 | 347 | "# Visualize the 20 images with the lowest and highest representation in the test dataset\n", |
|
345 | 354 | ] |
346 | 355 | }, |
347 | 356 | { |
| 357 | + "attachments": {}, |
348 | 358 | "cell_type": "markdown", |
349 | 359 | "metadata": { |
350 | 360 | "id": "-JYmGMJF71Ef" |
|
368 | 378 | ] |
369 | 379 | }, |
370 | 380 | { |
| 381 | + "attachments": {}, |
371 | 382 | "cell_type": "markdown", |
372 | 383 | "metadata": { |
373 | 384 | "id": "i8ERzg2-71Ef" |
|
389 | 400 | ] |
390 | 401 | }, |
391 | 402 | { |
| 403 | + "attachments": {}, |
392 | 404 | "cell_type": "markdown", |
393 | 405 | "metadata": { |
394 | 406 | "id": "cRNV-3SU71Eg" |
|
404 | 416 | ] |
405 | 417 | }, |
406 | 418 | { |
| 419 | + "attachments": {}, |
407 | 420 | "cell_type": "markdown", |
408 | 421 | "metadata": { |
409 | 422 | "id": "ww5lx7ue71Eg" |
|
420 | 433 | ] |
421 | 434 | }, |
422 | 435 | { |
| 436 | + "attachments": {}, |
423 | 437 | "cell_type": "markdown", |
424 | 438 | "metadata": { |
425 | 439 | "id": "NEfeWo2p7wKm" |
|
442 | 456 | "### Analyzing epistemic uncertainty estimates ###\n", |
443 | 457 | "\n", |
444 | 458 | "# Sort according to epistemic uncertainty estimates\n", |
445 | | - "epistemic_indices = np.argsort(uncertainty, axis=None) # sort the uncertainty values\n", |
| 459 | + "epistemic_indices = np.argsort(out.epistemic, axis=None) # sort the uncertainty values\n", |
446 | 460 | "epistemic_images = test_imgs[epistemic_indices] # sort images from lowest to highest uncertainty\n", |
447 | | - "sorted_epistemic = uncertainty[epistemic_indices] # order the uncertainty scores\n", |
448 | | - "sorted_epistemic_preds = predictions[epistemic_indices] # order the prediction values\n", |
| 461 | + "sorted_epistemic = out.epistemic.numpy()[epistemic_indices] # order the uncertainty scores\n", |
| 462 | + "sorted_epistemic_preds = out.y_hat.numpy()[epistemic_indices] # order the prediction values\n", |
449 | 463 | "\n", |
450 | 464 | "\n", |
451 | 465 | "# Visualize the 20 images with the LEAST and MOST epistemic uncertainty\n", |
|
458 | 472 | ] |
459 | 473 | }, |
460 | 474 | { |
| 475 | + "attachments": {}, |
461 | 476 | "cell_type": "markdown", |
462 | 477 | "metadata": { |
463 | 478 | "id": "L0dA8EyX71Eh" |
|
481 | 496 | ] |
482 | 497 | }, |
483 | 498 | { |
| 499 | + "attachments": {}, |
484 | 500 | "cell_type": "markdown", |
485 | 501 | "metadata": { |
486 | 502 | "id": "iyn0IE6x71Eh" |
|
496 | 512 | ] |
497 | 513 | }, |
498 | 514 | { |
| 515 | + "attachments": {}, |
499 | 516 | "cell_type": "markdown", |
500 | 517 | "metadata": { |
501 | 518 | "id": "XbwRbesM71Eh" |
|
561 | 578 | "\n", |
562 | 579 | " # After the epoch is done, recompute data sampling proabilities \n", |
563 | 580 | " # according to the inverse of the bias\n", |
564 | | - " pred, unc, bias = wrapper(train_imgs)\n", |
| 581 | + " out = wrapper(train_imgs)\n", |
565 | 582 | "\n", |
566 | 583 | " # Increase the probability of sampling under-represented datapoints by setting \n", |
567 | 584 | " # the probability to the **inverse** of the biases\n", |
568 | | - " inverse_bias = 1.0 / (bias.numpy() + 1e-7)\n", |
| 585 | + " inverse_bias = 1.0 / (np.mean(out.bias.numpy(),axis=-1) + 1e-7)\n", |
569 | 586 | "\n", |
570 | 587 | " # Normalize the inverse biases in order to convert them to probabilities\n", |
571 | 588 | " p_faces = inverse_bias / np.sum(inverse_bias)\n", |
|
575 | 592 | ] |
576 | 593 | }, |
577 | 594 | { |
| 595 | + "attachments": {}, |
578 | 596 | "cell_type": "markdown", |
579 | 597 | "metadata": { |
580 | 598 | "id": "SwXrAeBo71Ej" |
|
598 | 616 | "### Evaluation of debiased model ###\n", |
599 | 617 | "\n", |
600 | 618 | "# Get classification predictions, uncertainties, and representation bias scores\n", |
601 | | - "pred, unc, bias = wrapper.predict(test_imgs)\n", |
| 619 | + "out = wrapper.predict(test_imgs)\n", |
602 | 620 | "\n", |
603 | 621 | "# Sort according to lowest to highest representation scores\n", |
604 | | - "indices = np.argsort(bias, axis=None)\n", |
| 622 | + "indices = np.argsort(out.bias, axis=None)\n", |
605 | 623 | "bias_images = test_imgs[indices] # sort the images\n", |
606 | | - "sorted_bias = bias[indices] # sort the representation bias scores\n", |
607 | | - "sorted_bias_preds = pred[indices] # sort the predictions\n", |
| 624 | + "sorted_bias = out.bias.numpy()[indices] # sort the representation bias scores\n", |
| 625 | + "sorted_bias_preds = out.y_hat.numpy()[indices] # sort the predictions\n", |
608 | 626 | "\n", |
609 | 627 | "# Plot the representation bias vs. the accuracy\n", |
610 | 628 | "plt.xlabel(\"Density (Representation)\")\n", |
|
613 | 631 | ] |
614 | 632 | }, |
615 | 633 | { |
| 634 | + "attachments": {}, |
616 | 635 | "cell_type": "markdown", |
617 | 636 | "metadata": { |
618 | 637 | "id": "d1cEEnII71Ej" |
|
681 | 700 | "name": "python", |
682 | 701 | "nbconvert_exporter": "python", |
683 | 702 | "pygments_lexer": "ipython3", |
684 | | - "version": "3.9.6 (default, Oct 18 2022, 12:41:40) \n[Clang 14.0.0 (clang-1400.0.29.202)]" |
| 703 | + "version": "3.9.16" |
685 | 704 | }, |
686 | 705 | "vscode": { |
687 | 706 | "interpreter": { |
|
0 commit comments