Skip to content

Commit d2e49cd

Browse files
Merge pull request #792 from birdnet-team/focal-loss-autotune
Fix focal loss in autotune and apply sigmoid for test data evaluation
2 parents ee559da + 8e3dee4 commit d2e49cd

2 files changed

Lines changed: 16 additions & 0 deletions

File tree

birdnet_analyzer/model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,9 @@ def save_model_params(path):
478478
"Upsamling ratio",
479479
"use mixup",
480480
"use label smoothing",
481+
"use focal loss",
482+
"focal loss alpha",
483+
"focal loss gamma",
481484
"BirdNET Model version",
482485
),
483486
(
@@ -492,6 +495,9 @@ def save_model_params(path):
492495
cfg.UPSAMPLING_RATIO,
493496
cfg.TRAIN_WITH_MIXUP,
494497
cfg.TRAIN_WITH_LABEL_SMOOTHING,
498+
cfg.TRAIN_WITH_FOCAL_LOSS,
499+
cfg.FOCAL_LOSS_ALPHA,
500+
cfg.FOCAL_LOSS_GAMMA,
495501
cfg.MODEL_VERSION,
496502
),
497503
)

birdnet_analyzer/train/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,10 @@ def run_trial(self, trial, *args, **kwargs):
496496
cfg.UPSAMPLING_RATIO = best_params["upsampling_ratio"]
497497
cfg.TRAIN_WITH_MIXUP = best_params["mixup"]
498498
cfg.TRAIN_WITH_LABEL_SMOOTHING = best_params["label_smoothing"]
499+
cfg.TRAIN_WITH_FOCAL_LOSS = best_params["focal_loss"]
500+
if cfg.TRAIN_WITH_FOCAL_LOSS:
501+
cfg.FOCAL_LOSS_ALPHA = best_params["focal_loss_alpha"]
502+
cfg.FOCAL_LOSS_GAMMA = best_params["focal_loss_gamma"]
499503

500504
print("Best params: ")
501505
print("hidden_units: ", cfg.TRAIN_HIDDEN_UNITS)
@@ -507,6 +511,10 @@ def run_trial(self, trial, *args, **kwargs):
507511
print("upsampling_mode: ", cfg.UPSAMPLING_MODE)
508512
print("mixup: ", cfg.TRAIN_WITH_MIXUP)
509513
print("label_smoothing: ", cfg.TRAIN_WITH_LABEL_SMOOTHING)
514+
print("focal_loss: ", cfg.TRAIN_WITH_FOCAL_LOSS)
515+
if cfg.TRAIN_WITH_FOCAL_LOSS:
516+
print("focal_loss_alpha: ", cfg.FOCAL_LOSS_ALPHA)
517+
print("focal_loss_gamma: ", cfg.FOCAL_LOSS_GAMMA)
510518

511519
# Build model
512520
print("Building model...", flush=True)
@@ -724,6 +732,8 @@ def evaluate_model(classifier, x_test, y_test, labels, threshold=None):
724732
# Make predictions
725733
y_pred_prob = classifier.predict(x_test)
726734

735+
y_pred_prob = model.flat_sigmoid(y_pred_prob, sensitivity=-1, bias=1.0)
736+
727737
# Calculate metrics for each class
728738
metrics = {}
729739

0 commit comments

Comments
 (0)