Skip to content

Commit a62d4ce

Browse files
Merge pull request #817 from birdnet-team/removed-embedding-normalization
Removed the embedding normalization step that ran after data loading during training.
2 parents: 7ae6a07 + 804b097 · commit a62d4ce

1 file changed

Lines changed: 0 additions & 29 deletions

File tree

birdnet_analyzer/train/utils.py

Lines changed: 0 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -263,29 +263,6 @@ def load_data(data_path, allowed_folders):
263263
# Return only the valid labels for further use
264264
return x_train, y_train, x_test, y_test, valid_labels
265265

266-
267-
def normalize_embeddings(embeddings):
268-
"""
269-
Normalize embeddings to improve training stability and performance.
270-
271-
This applies L2 normalization to each embedding vector, which can help
272-
with convergence and model performance, especially when training on
273-
embeddings from different sources or domains.
274-
275-
Args:
276-
embeddings: numpy array of embedding vectors
277-
278-
Returns:
279-
Normalized embeddings array
280-
"""
281-
# Calculate L2 norm of each embedding vector
282-
norms = np.sqrt(np.sum(embeddings**2, axis=1, keepdims=True))
283-
# Avoid division by zero
284-
norms[norms == 0] = 1.0
285-
# Normalize each embedding vector
286-
return embeddings / norms
287-
288-
289266
def train_model(on_epoch_end=None, on_trial_result=None, on_data_load_end=None, autotune_directory="autotune"):
290267
"""Trains a custom classifier.
291268
@@ -310,12 +287,6 @@ def train_model(on_epoch_end=None, on_trial_result=None, on_data_load_end=None,
310287
if len(x_test) > 0:
311288
print(f"...Loaded {x_test.shape[0]} test samples.", flush=True)
312289

313-
# Normalize embeddings
314-
print("Normalizing embeddings...", flush=True)
315-
x_train = normalize_embeddings(x_train)
316-
if len(x_test) > 0:
317-
x_test = normalize_embeddings(x_test)
318-
319290
if cfg.AUTOTUNE:
320291
import gc
321292

0 commit comments

Comments
 (0)