
Commit c4e8b04

Fix custom classifier false positives by removing BatchNormalization layers (#824)
Resolves a train-inference mismatch that caused ~100% confidence scores on background noise in custom classifiers trained with v2.0.0+.

Root cause: BatchNormalization layers were added in v2.0.0, but they received different input distributions during training (L2-normalized embeddings) and inference (raw embeddings), which broke the learned decision boundaries.

Solution: Remove the BatchNormalization layers from build_linear_classifier() to restore the reliable v1.5.1 behavior while keeping the other v2.x improvements (L2 regularization, improved learning rate schedule).

Fixes #823
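To illustrate the failure mode, here is a minimal, self-contained sketch (not code from this repo; the shapes, scales, and momentum setting are made up). A BatchNormalization layer whose moving statistics were accumulated on L2-normalized embeddings rescales raw embeddings far outside the distribution the downstream sigmoid saw during training:

# Hypothetical demo of the train/inference distribution mismatch; not repo code.
import numpy as np
from tensorflow import keras

rng = np.random.default_rng(42)
raw = rng.normal(0.0, 4.0, size=(256, 1024)).astype("float32")  # raw embeddings (made-up scale)
normed = raw / np.linalg.norm(raw, axis=1, keepdims=True)       # L2-normalized, as during training

bn = keras.layers.BatchNormalization(momentum=0.5)  # low momentum so moving stats converge quickly
for _ in range(25):
    bn(normed, training=True)  # training=True updates moving_mean/moving_variance

# At inference, inputs are normalized with moving stats learned on the *normalized* data.
print(float(np.abs(bn(normed, training=False).numpy()).mean()))  # ~0.8: in-distribution
print(float(np.abs(bn(raw, training=False).numpy()).mean()))     # ~100: raw inputs blown up

A sigmoid output fed pre-activations this large saturates at ~1.0, i.e. near-100% confidence even on background noise, which matches the symptom reported in #823.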
1 parent a62d4ce commit c4e8b04

1 file changed: 0 additions & 6 deletions

birdnet_analyzer/model.py

@@ -635,9 +635,6 @@ def build_linear_classifier(num_labels, input_size, hidden_units=0, dropout=0.0)
     # Input layer
     model.add(keras.layers.InputLayer(input_shape=(input_size,)))
 
-    # Batch normalization on input to standardize embeddings
-    model.add(keras.layers.BatchNormalization())
-
     # Optional L2 regularization for all dense layers
     regularizer = keras.regularizers.l2(1e-5)
 
@@ -650,9 +647,6 @@ def build_linear_classifier(num_labels, input_size, hidden_units=0, dropout=0.0)
     # Add a hidden layer with L2 regularization
     model.add(keras.layers.Dense(hidden_units, activation="relu", kernel_regularizer=regularizer, kernel_initializer="he_normal"))
 
-    # Add another batch normalization after the hidden layer
-    model.add(keras.layers.BatchNormalization())
-
     # Dropout layer before output
     if dropout > 0:
         model.add(keras.layers.Dropout(dropout))
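For context, a sketch of what the classifier head looks like after this change, reconstructed from the diff context above. The Sequential model, the hidden_units guard, the import path, and the sigmoid output layer are assumptions; they are not visible in the diff:

# Reconstruction sketch; lines marked "assumed" are not shown in the diff above.
from tensorflow import keras  # assumed import path

def build_linear_classifier(num_labels, input_size, hidden_units=0, dropout=0.0):
    model = keras.Sequential()  # assumed: diff only shows model.add(...) calls

    # Input layer
    model.add(keras.layers.InputLayer(input_shape=(input_size,)))

    # Optional L2 regularization for all dense layers
    regularizer = keras.regularizers.l2(1e-5)

    if hidden_units > 0:  # assumed guard around the optional hidden layer
        # Add a hidden layer with L2 regularization
        model.add(keras.layers.Dense(hidden_units, activation="relu",
                                     kernel_regularizer=regularizer,
                                     kernel_initializer="he_normal"))

    # Dropout layer before output
    if dropout > 0:
        model.add(keras.layers.Dropout(dropout))

    # Output layer (assumed): one sigmoid unit per label for multi-label scores
    model.add(keras.layers.Dense(num_labels, activation="sigmoid"))
    return model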
