Skip to content

Commit fa187eb

Browse files
streamline chunk generation for embeddings and analyze (#727)
* streamline chunk generation for embeddings and analyze * - * - * - * timestamps for embeddings are now stored with respect to the audio speed --------- Co-authored-by: Max Mauermann <max-mauermann@web.de>
1 parent 928bf9f commit fa187eb

7 files changed

Lines changed: 85 additions & 128 deletions

File tree

birdnet_analyzer/analyze/utils.py

Lines changed: 65 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ def combine_csv_files(saved_results: list[str]):
396396
f.write(out_string)
397397

398398

399-
def combine_results(saved_results: Sequence[dict[str, str]| None]):
399+
def combine_results(saved_results: Sequence[dict[str, str] | None]):
400400
"""
401401
Combines various types of result files based on the configuration settings.
402402
This function checks the types of results specified in the configuration
@@ -522,6 +522,56 @@ def get_raw_audio_from_file(fpath: str, offset, duration):
522522
return audio.split_signal(sig, rate, cfg.SIG_LENGTH, cfg.SIG_OVERLAP, cfg.SIG_MINLEN)
523523

524524

525+
def iterate_audio_chunks(fpath: str, embeddings: bool = False):
526+
"""Iterates over audio chunks from a file.
527+
528+
Args:
529+
fpath: Path to the audio file.
530+
embeddings: If True, yield model embeddings instead of class predictions.
531+
532+
Yields:
533+
Tuples of (start_seconds, end_seconds, prediction-or-embedding) per chunk.
534+
"""
535+
fileLengthSeconds = audio.get_audio_file_length(fpath)
536+
start, end = 0, cfg.SIG_LENGTH * cfg.AUDIO_SPEED
537+
duration = int(cfg.FILE_SPLITTING_DURATION / cfg.AUDIO_SPEED)
538+
539+
while start < fileLengthSeconds and not np.isclose(start, fileLengthSeconds):
540+
chunks = get_raw_audio_from_file(fpath, start, duration)
541+
samples = []
542+
timestamps = []
543+
544+
if not chunks:
545+
break
546+
547+
for chunk_index, chunk in enumerate(chunks):
548+
# Add to batch
549+
samples.append(chunk)
550+
timestamps.append([round(start, 1), round(end, 1)])
551+
552+
# Advance start and end
553+
start += (cfg.SIG_LENGTH - cfg.SIG_OVERLAP) * cfg.AUDIO_SPEED
554+
end = min(start + cfg.SIG_LENGTH * cfg.AUDIO_SPEED, fileLengthSeconds)
555+
556+
# Check if batch is full or last chunk
557+
if len(samples) < cfg.BATCH_SIZE and chunk_index < len(chunks) - 1:
558+
continue
559+
560+
# Predict
561+
p = model.embeddings(samples) if embeddings else predict(samples)
562+
563+
# Add to results
564+
for i in range(len(samples)):
565+
# Get timestamp
566+
s_start, s_end = timestamps[i]
567+
568+
yield s_start, s_end, p[i]
569+
570+
# Clear batch
571+
samples = []
572+
timestamps = []
573+
574+
525575
def predict(samples):
526576
"""Predicts the classes for the given samples.
527577
@@ -600,76 +650,31 @@ def analyze_file(item) -> dict[str, str] | None:
600650

601651
# Start time
602652
start_time = datetime.datetime.now()
603-
duration = int(cfg.FILE_SPLITTING_DURATION / cfg.AUDIO_SPEED)
604-
start, end = 0, cfg.SIG_LENGTH * cfg.AUDIO_SPEED
605653
results = {}
606654

607655
# Status
608656
print(f"Analyzing {fpath}", flush=True)
609657

610-
try:
611-
fileLengthSeconds = audio.get_audio_file_length(fpath)
612-
except Exception as ex:
613-
# Write error log
614-
print(f"Error: Cannot analyze audio file {fpath}. File corrupt?\n", flush=True)
615-
utils.write_error_log(ex)
616-
617-
return None
618-
619658
# Process each chunk
620659
try:
621-
while start < fileLengthSeconds and not np.isclose(start, fileLengthSeconds):
622-
chunks = get_raw_audio_from_file(fpath, start, duration)
623-
samples = []
624-
timestamps = []
625-
626-
for chunk_index, chunk in enumerate(chunks):
627-
# Add to batch
628-
samples.append(chunk)
629-
timestamps.append([round(start, 1), round(end, 1)])
630-
631-
# Advance start and end
632-
start += (cfg.SIG_LENGTH - cfg.SIG_OVERLAP) * cfg.AUDIO_SPEED
633-
end = min(start + cfg.SIG_LENGTH * cfg.AUDIO_SPEED, fileLengthSeconds)
634-
635-
# Check if batch is full or last chunk
636-
if len(samples) < cfg.BATCH_SIZE and chunk_index < len(chunks) - 1:
637-
continue
638-
639-
# Predict
640-
p = predict(samples)
641-
642-
# Add to results
643-
for i in range(len(samples)):
644-
# Get timestamp
645-
s_start, s_end = timestamps[i]
646-
647-
# Get prediction
648-
pred = p[i]
649-
650-
if not cfg.LABELS:
651-
cfg.LABELS = [f"Species-{i}_Species-{i}" for i in range(len(pred))]
652-
653-
# Assign scores to labels
654-
p_labels = [
655-
p
656-
for p in zip(cfg.LABELS, pred, strict=True)
657-
if (cfg.TOP_N or p[1] >= cfg.MIN_CONFIDENCE) and (not cfg.SPECIES_LIST or p[0] in cfg.SPECIES_LIST)
658-
]
660+
for s_start, s_end, pred in iterate_audio_chunks(fpath):
661+
if not cfg.LABELS:
662+
cfg.LABELS = [f"Species-{i}_Species-{i}" for i in range(len(pred))]
659663

660-
# Sort by score
661-
p_sorted = sorted(p_labels, key=operator.itemgetter(1), reverse=True)
664+
# Assign scores to labels
665+
p_labels = [
666+
p for p in zip(cfg.LABELS, pred, strict=True) if (cfg.TOP_N or p[1] >= cfg.MIN_CONFIDENCE) and (not cfg.SPECIES_LIST or p[0] in cfg.SPECIES_LIST)
667+
]
662668

663-
if cfg.TOP_N:
664-
p_sorted = p_sorted[: cfg.TOP_N]
669+
# Sort by score
670+
p_sorted = sorted(p_labels, key=operator.itemgetter(1), reverse=True)
665671

666-
# TODO: hier schon top n oder min conf raussortieren
667-
# Store top 5 results and advance indices
668-
results[str(s_start) + "-" + str(s_end)] = p_sorted
672+
if cfg.TOP_N:
673+
p_sorted = p_sorted[: cfg.TOP_N]
669674

670-
# Clear batch
671-
samples = []
672-
timestamps = []
675+
# TODO: filter by top n or min confidence already at this point
676+
# Store sorted results for this timestamp and advance indices
677+
results[str(s_start) + "-" + str(s_end)] = p_sorted
673678

674679
except Exception as ex:
675680
# Write error log

birdnet_analyzer/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ def search_parser():
456456
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, parents=parents)
457457
parser.add_argument("-q", "--queryfile", help="Path to the query file.")
458458
parser.add_argument("-o", "--output", help="Path to the output folder.")
459-
parser.add_argument("--n_results", default=10, help="Number of results to return.")
459+
parser.add_argument("--n_results", default=10, type=int, help="Number of results to return.")
460460

461461
# TODO: use choice argument.
462462
parser.add_argument(

birdnet_analyzer/embeddings/utils.py

Lines changed: 12 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
from tqdm import tqdm
1313

1414
import birdnet_analyzer.config as cfg
15-
from birdnet_analyzer import audio, model, utils
16-
from birdnet_analyzer.analyze.utils import get_raw_audio_from_file
15+
from birdnet_analyzer import utils
16+
from birdnet_analyzer.analyze.utils import iterate_audio_chunks
1717
from birdnet_analyzer.embeddings.core import get_database
1818

1919
DATASET_NAME: str = "birdnet_analyzer_dataset"
@@ -30,18 +30,6 @@ def analyze_file(item, db: sqlite_usearch_impl.SQLiteUsearchDB):
3030
fpath: str = item[0]
3131
cfg.set_config(item[1])
3232

33-
offset = 0
34-
duration = cfg.FILE_SPLITTING_DURATION
35-
36-
try:
37-
fileLengthSeconds = int(audio.get_audio_file_length(fpath))
38-
except Exception as ex:
39-
# Write error log
40-
print(f"Error: Cannot analyze audio file {fpath}. File corrupt?\n", flush=True)
41-
utils.write_error_log(ex)
42-
43-
return
44-
4533
# Start time
4634
start_time = datetime.datetime.now()
4735

@@ -52,53 +40,17 @@ def analyze_file(item, db: sqlite_usearch_impl.SQLiteUsearchDB):
5240

5341
# Process each chunk
5442
try:
55-
while offset < fileLengthSeconds:
56-
chunks = get_raw_audio_from_file(fpath, offset, duration)
57-
start, end = offset, cfg.SIG_LENGTH + offset
58-
samples = []
59-
timestamps = []
60-
61-
for c in range(len(chunks)):
62-
# Add to batch
63-
samples.append(chunks[c])
64-
timestamps.append([start, end])
43+
for s_start, s_end, embeddings in iterate_audio_chunks(fpath, embeddings=True):
44+
# Check if embedding already exists
45+
existing_embedding = db.get_embeddings_by_source(DATASET_NAME, source_id, np.array([s_start, s_end]))
6546

66-
# Advance start and end
67-
start += cfg.SIG_LENGTH - cfg.SIG_OVERLAP
68-
end = start + cfg.SIG_LENGTH
47+
if existing_embedding.size == 0:
48+
# Store embeddings
49+
embeddings_source = hoplite.EmbeddingSource(DATASET_NAME, source_id, np.array([s_start, s_end]))
6950

70-
# Check if batch is full or last chunk
71-
if len(samples) < cfg.BATCH_SIZE and c < len(chunks) - 1:
72-
continue
73-
74-
# Prepare sample and pass through model
75-
data = np.array(samples, dtype="float32")
76-
e = model.embeddings(data)
77-
78-
# Add to results
79-
for i in range(len(samples)):
80-
# Get timestamp
81-
s_start, s_end = timestamps[i]
82-
83-
# Check if embedding already exists
84-
existing_embedding = db.get_embeddings_by_source(DATASET_NAME, source_id, np.array([s_start, s_end]))
85-
86-
if existing_embedding.size == 0:
87-
# Get prediction
88-
embeddings = e[i]
89-
90-
# Store embeddings
91-
embeddings_source = hoplite.EmbeddingSource(DATASET_NAME, source_id, np.array([s_start, s_end]))
92-
93-
# Insert into database
94-
db.insert_embedding(embeddings, embeddings_source)
95-
db.commit()
96-
97-
# Reset batch
98-
samples = []
99-
timestamps = []
100-
101-
offset = offset + duration
51+
# Insert into database
52+
db.insert_embedding(embeddings, embeddings_source)
53+
db.commit()
10254

10355
except Exception as ex:
10456
# Write error log
@@ -162,6 +114,7 @@ def create_file_output(output_path: str, db: sqlite_usearch_impl.SQLiteUsearchDB
162114
with open(target_path, "w") as f:
163115
f.write(",".join(map(str, embedding.tolist())))
164116

117+
165118
def run(audio_input, database, overlap, audio_speed, fmin, fmax, threads, batchsize, file_output):
166119
### Make sure to comment out appropriately if you are not using args. ###
167120

birdnet_analyzer/gui/embeddings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -459,8 +459,8 @@ def render_results(results, page, db_path, exports):
459459
index = i + page * PAGE_SIZE
460460
embedding_source = db.get_embedding_source(r.embedding_id)
461461
file = embedding_source.source_id
462-
offset = embedding_source.offsets[0] * settings["AUDIO_SPEED"]
463-
duration = 3 * settings["AUDIO_SPEED"]
462+
offset = embedding_source.offsets[0]
463+
duration = cfg.SIG_LENGTH * settings["AUDIO_SPEED"]
464464
spec = utils.spectrogram_from_file(
465465
file,
466466
offset=offset,

birdnet_analyzer/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1201,12 +1201,14 @@ def embeddings(sample):
12011201

12021202
load_model(False)
12031203

1204+
sample = np.array(sample, dtype="float32")
1205+
12041206
# Reshape input tensor
12051207
INTERPRETER.resize_tensor_input(INPUT_LAYER_INDEX, [len(sample), *sample[0].shape])
12061208
INTERPRETER.allocate_tensors()
12071209

12081210
# Extract feature embeddings
1209-
INTERPRETER.set_tensor(INPUT_LAYER_INDEX, np.array(sample, dtype="float32"))
1211+
INTERPRETER.set_tensor(INPUT_LAYER_INDEX, sample)
12101212
INTERPRETER.invoke()
12111213

12121214
return INTERPRETER.get_tensor(OUTPUT_LAYER_INDEX)

birdnet_analyzer/search/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def search(
6565
file = embedding_source.source_id
6666
filebasename = os.path.basename(file)
6767
filebasename = os.path.splitext(filebasename)[0]
68-
offset = embedding_source.offsets[0] * audio_speed
68+
offset = embedding_source.offsets[0]
6969
duration = cfg.SIG_LENGTH * audio_speed
7070
sig, rate = audio.open_audio_file(file, offset=offset, duration=duration, sample_rate=None)
7171
result_path = os.path.join(output, f"{r.sort_score:.5f}_{filebasename}_{offset}_{offset + duration}.wav")

birdnet_analyzer/search/utils.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,7 @@ def get_query_embedding(queryfile_path):
4949
else:
5050
sig_splits = audio.split_signal(sig, rate, cfg.SIG_LENGTH, cfg.SIG_OVERLAP, cfg.SIG_MINLEN)
5151

52-
samples = sig_splits
53-
data = np.array(samples, dtype="float32")
54-
55-
return model.embeddings(data)
52+
return model.embeddings(sig_splits)
5653

5754

5855
def get_search_results(

0 commit comments

Comments
 (0)