Skip to content

Commit 3897676

Browse files
committed
focal loss gets used after autotune and listed in params-output.
sigmoid is applied for metrics calculation, when using test data
1 parent 78b2561 commit 3897676

2 files changed

Lines changed: 16 additions & 0 deletions

File tree

birdnet_analyzer/model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,9 @@ def save_model_params(path):
469469
"Upsamling ratio",
470470
"use mixup",
471471
"use label smoothing",
472+
"use focal loss",
473+
"focal loss alpha",
474+
"focal loss gamma",
472475
"BirdNET Model version",
473476
),
474477
(
@@ -483,6 +486,9 @@ def save_model_params(path):
483486
cfg.UPSAMPLING_RATIO,
484487
cfg.TRAIN_WITH_MIXUP,
485488
cfg.TRAIN_WITH_LABEL_SMOOTHING,
489+
cfg.TRAIN_WITH_FOCAL_LOSS,
490+
cfg.FOCAL_LOSS_ALPHA,
491+
cfg.FOCAL_LOSS_GAMMA,
486492
cfg.MODEL_VERSION,
487493
),
488494
)

birdnet_analyzer/train/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,10 @@ def run_trial(self, trial, *args, **kwargs):
512512
cfg.UPSAMPLING_RATIO = best_params["upsampling_ratio"]
513513
cfg.TRAIN_WITH_MIXUP = best_params["mixup"]
514514
cfg.TRAIN_WITH_LABEL_SMOOTHING = best_params["label_smoothing"]
515+
cfg.TRAIN_WITH_FOCAL_LOSS = best_params["focal_loss"]
516+
if cfg.TRAIN_WITH_FOCAL_LOSS:
517+
cfg.FOCAL_LOSS_ALPHA = best_params["focal_loss_alpha"]
518+
cfg.FOCAL_LOSS_GAMMA = best_params["focal_loss_gamma"]
515519

516520
print("Best params: ")
517521
print("hidden_units: ", cfg.TRAIN_HIDDEN_UNITS)
@@ -523,6 +527,10 @@ def run_trial(self, trial, *args, **kwargs):
523527
print("upsampling_mode: ", cfg.UPSAMPLING_MODE)
524528
print("mixup: ", cfg.TRAIN_WITH_MIXUP)
525529
print("label_smoothing: ", cfg.TRAIN_WITH_LABEL_SMOOTHING)
530+
print("focal_loss: ", cfg.TRAIN_WITH_FOCAL_LOSS)
531+
if cfg.TRAIN_WITH_FOCAL_LOSS:
532+
print("focal_loss_alpha: ", cfg.FOCAL_LOSS_ALPHA)
533+
print("focal_loss_gamma: ", cfg.FOCAL_LOSS_GAMMA)
526534

527535
# Build model
528536
print("Building model...", flush=True)
@@ -740,6 +748,8 @@ def evaluate_model(classifier, x_test, y_test, labels, threshold=None):
740748
# Make predictions
741749
y_pred_prob = classifier.predict(x_test)
742750

751+
y_pred_prob = model.flat_sigmoid(y_pred_prob, sensitivity=-1, bias=1.0)
752+
743753
# Calculate metrics for each class
744754
metrics = {}
745755

0 commit comments

Comments
 (0)