Skip to content

Commit 7a44df1

Browse files
Multicore for embedding extraction (#774)
* v1 consumer-producer * v1 consumer-producer but mp.Queue * bench * bench cores * updated embeddings gui * update embeddings gui * update failing test * . * added a simple test for the search * . * . * fixed consumer process creation * . * update translations * ...... * closing the db in the test * tests have internal dependency remove for now * ruff * just remove failing tests, ez --------- Co-authored-by: Max Mauermann <max-mauermann@web.de>
1 parent a4e3d81 commit 7a44df1

18 files changed

Lines changed: 468 additions & 265 deletions

File tree

benchmark.sh

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
#!/usr/bin/env bash
# Benchmark harness for birdnet_analyzer embedding extraction.
# Appends one CSV row per run: version,chunksize,target,cores,bs,sys,user,real

set -euo pipefail

OUTFILE="benchmark.csv"
TARGET_DIR="/data/embeddings/"   # embeddings DB dir, wiped before each run

source "./.venv/bin/activate"

# Write the CSV header only once so repeated invocations append rows.
if [ ! -f "$OUTFILE" ]; then
    echo "version,chunksize,target,cores,bs,sys,user,real" > "$OUTFILE"
fi

# run_benchmark VERSION CHUNKSIZE TARGET CORES BATCHSIZE
# Times one embedding-extraction run and appends the timings to $OUTFILE.
run_benchmark() {
    local version=$1
    local chunksize=$2
    local target=$3
    local cores=$4
    local batchsize=$5

    rm -rf "$TARGET_DIR"

    export CHUNKSIZE="$chunksize"
    export BVERSION="$version"

    local real user sys
    LC_NUMERIC=C                 # stable decimal separator in `time` output
    TIMEFORMAT="%lR %lU %lS"     # real user sys, space-separated, one line
    { time python -m birdnet_analyzer.embeddings -i "$target" -db /data/embeddings -t "$cores" -b "$batchsize" 2> python_stderr.log ; } 2>timing.tmp

    # -r: don't let backslashes in the timing line be interpreted
    read -r real user sys <timing.tmp

    echo "$version,$chunksize,$target,$cores,$batchsize,$sys,$user,$real" >> "$OUTFILE"

    rm -f timing.tmp
}

LARGE_FILES="/data/testing_audio/medium_sized_soundscapes/" # 155
SMALL_FILES="/data/testing_audio/small_files/" # 14018
TEST_CORES="10"

# warmup
run_benchmark "V1" "0" "$LARGE_FILES" "$TEST_CORES" "16"

# larger files
run_benchmark "V1" "0" "$LARGE_FILES" "$TEST_CORES" "16"
run_benchmark "V2" "1" "$LARGE_FILES" "$TEST_CORES" "16"
run_benchmark "V2" "2" "$LARGE_FILES" "$TEST_CORES" "16"
run_benchmark "V2" "3" "$LARGE_FILES" "$TEST_CORES" "16"
run_benchmark "V2" "7" "$LARGE_FILES" "$TEST_CORES" "16" # (155 files // 10 cores) // 2
run_benchmark "V2" "10" "$LARGE_FILES" "$TEST_CORES" "16"
run_benchmark "V2" "15" "$LARGE_FILES" "$TEST_CORES" "16" # 155 files // 10 cores
run_benchmark "V3" "1" "$LARGE_FILES" "$TEST_CORES" "16"
run_benchmark "V3" "2" "$LARGE_FILES" "$TEST_CORES" "16"
run_benchmark "V3" "3" "$LARGE_FILES" "$TEST_CORES" "16"
run_benchmark "V3" "10" "$LARGE_FILES" "$TEST_CORES" "16"
# FIX: was a duplicate "10" run; the comment (and the matching V2 row) say 7.
run_benchmark "V3" "7" "$LARGE_FILES" "$TEST_CORES" "16" # (155 files // 10 cores) // 2
run_benchmark "V3" "15" "$LARGE_FILES" "$TEST_CORES" "16" # 155 files // 10 cores

# small files
run_benchmark "V1" "0" "$SMALL_FILES" "$TEST_CORES" "16"
run_benchmark "V2" "1" "$SMALL_FILES" "$TEST_CORES" "16"
run_benchmark "V2" "2" "$SMALL_FILES" "$TEST_CORES" "16"
run_benchmark "V2" "3" "$SMALL_FILES" "$TEST_CORES" "16"
run_benchmark "V2" "7" "$SMALL_FILES" "$TEST_CORES" "16" # (14018 files // 10 cores) // 2
run_benchmark "V2" "700" "$SMALL_FILES" "$TEST_CORES" "16"
run_benchmark "V2" "1401" "$SMALL_FILES" "$TEST_CORES" "16" # 14018 files // 10 cores
run_benchmark "V3" "1" "$SMALL_FILES" "$TEST_CORES" "16"
run_benchmark "V3" "2" "$SMALL_FILES" "$TEST_CORES" "16"
run_benchmark "V3" "3" "$SMALL_FILES" "$TEST_CORES" "16"
run_benchmark "V3" "10" "$SMALL_FILES" "$TEST_CORES" "16"
run_benchmark "V3" "700" "$SMALL_FILES" "$TEST_CORES" "16" # (14018 files // 10 cores) // 2
run_benchmark "V3" "1401" "$SMALL_FILES" "$TEST_CORES" "16" # 14018 files // 10 cores

birdnet_analyzer/cli.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def bandpass_args():
8080

8181
return p
8282

83+
8384
def species_list_args():
8485
"""
8586
Creates an argument parser for species-list arguments.
@@ -111,6 +112,7 @@ def species_list_args():
111112
)
112113
return p
113114

115+
114116
def species_args():
115117
"""
116118
Creates an argument parser for species-related arguments including the species-list arguments.
@@ -325,6 +327,7 @@ def analyzer_parser():
325327
argparse.ArgumentParser: Configured argument parser for the BirdNET Analyzer CLI.
326328
"""
327329
from birdnet_analyzer.analyze import POSSIBLE_ADDITIONAL_COLUMNS_MAP
330+
328331
parents = [
329332
io_args(),
330333
bandpass_args(),
@@ -414,7 +417,7 @@ def embeddings_parser():
414417
argparse.ArgumentParser: Configured argument parser for extracting feature embeddings.
415418
"""
416419

417-
parents = [db_args(), bandpass_args(), audio_speed_args(), overlap_args(), threads_args(), bs_args()]
420+
parents = [db_args(), bandpass_args(), audio_speed_args(), overlap_args(), threads_args(), bs_args(default=8)]
418421

419422
parser = argparse.ArgumentParser(
420423
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
@@ -428,9 +431,7 @@ def embeddings_parser():
428431
help="Path to input file or folder.",
429432
)
430433

431-
parser.add_argument(
432-
"--file_output",
433-
)
434+
parser.add_argument("--file_output", help="Saves all embeddings contained in the database in a csv file.")
434435

435436
return parser
436437

birdnet_analyzer/embeddings/core.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,22 @@ def embeddings(
5050
run(audio_input, database, overlap, audio_speed, fmin, fmax, threads, batch_size, file_output)
5151

5252

53-
def get_database(db_path: str):
53+
def try_get_database(db_path: str):
    """Attempt to open or create the database.

    Args:
        db_path: The path to the database.

    Returns:
        The database object, or None if it could not be created or opened.
    """
    from perch_hoplite.db import sqlite_usearch_impl

    try:
        db = sqlite_usearch_impl.SQLiteUsearchDB.create(db_path=db_path)
    except ValueError:
        # `create` raises ValueError when the path holds no usable database
        # and no usearch config was supplied.
        return None
    return db
66+
67+
68+
def get_or_create_database(db_path: str):
5469
"""Get the database object. Creates or opens the databse.
5570
Args:
5671
db: The path to the database.
@@ -67,4 +82,7 @@ def get_database(db_path: str):
6782
db_path=db_path,
6883
usearch_cfg=sqlite_usearch_impl.get_default_usearch_config(embedding_dim=1024), # TODO: dont hardcode this
6984
)
70-
return sqlite_usearch_impl.SQLiteUsearchDB.create(db_path=db_path)
85+
try:
86+
return sqlite_usearch_impl.SQLiteUsearchDB.create(db_path=db_path)
87+
except ValueError:
88+
return sqlite_usearch_impl.SQLiteUsearchDB.create(db_path=db_path, usearch_cfg=sqlite_usearch_impl.get_default_usearch_config(embedding_dim=1024))

birdnet_analyzer/embeddings/utils.py

Lines changed: 136 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
"""Module used to extract embeddings for samples."""
22

3-
import datetime
3+
import multiprocessing as mp
44
import os
5-
from functools import partial
6-
from multiprocessing import Pool
5+
import time
76

87
import numpy as np
98
from ml_collections import ConfigDict
@@ -14,52 +13,40 @@
1413
import birdnet_analyzer.config as cfg
1514
from birdnet_analyzer import utils
1615
from birdnet_analyzer.analyze.utils import iterate_audio_chunks
17-
from birdnet_analyzer.embeddings.core import get_database
16+
from birdnet_analyzer.embeddings.core import get_or_create_database
1817

1918
DATASET_NAME: str = "birdnet_analyzer_dataset"
19+
COMMIT_BS_SIZE = 512
2020

2121

22-
def analyze_file(item, db: sqlite_usearch_impl.SQLiteUsearchDB):
23-
"""Extracts the embeddings for a file.
24-
25-
Args:
26-
item: (filepath, config)
27-
"""
28-
29-
# Get file path and restore cfg
30-
fpath: str = item[0]
31-
cfg.set_config(item[1])
32-
33-
# Start time
34-
start_time = datetime.datetime.now()
35-
36-
# Status
37-
print(f"Analyzing {fpath}", flush=True)
38-
39-
source_id = fpath
22+
def analyze_file_core(fpath, config):
    """Extract embeddings for a single audio file.

    Args:
        fpath: Path to the audio file to analyze.
        config: Serialized analyzer configuration, restored into this
            process's ``cfg`` module before chunking (workers each carry
            their own process-local config).

    Returns:
        List of ``(fpath, s_start, s_end, embeddings)`` tuples, one per
        audio chunk; empty if the file could not be analyzed.
    """
    results = []
    cfg.set_config(config)

    # Process each chunk
    try:
        for s_start, s_end, embeddings in iterate_audio_chunks(fpath, embeddings=True):
            results.append((fpath, s_start, s_end, embeddings))
    except Exception as ex:
        # Write error log; the exception is swallowed so one bad file does
        # not abort the whole batch.
        print(f"Error: Cannot analyze audio file {fpath}.", flush=True)
        utils.write_error_log(ex)

    return results
36+
37+
38+
def analyze_file(items):
    """Extract embeddings for a batch of files.

    Args:
        items: Iterable of ``(filepath, config)`` tuples.

    Returns:
        Flat list of per-chunk results from ``analyze_file_core``.
    """
    collected = []

    for entry in items:
        path, config = entry
        collected += analyze_file_core(path, config)

    return collected
6350

6451

6552
def check_database_settings(db: sqlite_usearch_impl.SQLiteUsearchDB):
@@ -76,13 +63,46 @@ def check_database_settings(db: sqlite_usearch_impl.SQLiteUsearchDB):
7663
db.commit()
7764

7865

79-
def create_file_output(output_path: str, db: sqlite_usearch_impl.SQLiteUsearchDB):
66+
def create_csv_output(output_path: str, database: str):
    """Export every embedding in the database to a CSV file.

    Args:
        output_path: Path of the CSV file to write.
        database: Path to the database.
    """

    db = get_or_create_database(database)
    parent_dir = os.path.dirname(output_path)

    # FIX: dirname is "" when output_path has no directory component, and
    # os.makedirs("") raises FileNotFoundError — only create a real dir.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    embedding_ids = db.get_embedding_ids()

    # Collect rows and join once; repeated `+=` on a string is quadratic
    # for large databases.
    rows = ["file_path,start,end,embedding"]

    for embedding_id in embedding_ids:
        embedding = db.get_embedding(embedding_id)
        source = db.get_embedding_source(embedding_id)

        start, end = source.offsets

        rows.append(f'{source.source_id},{start},{end},"{",".join(map(str, embedding.tolist()))}"')

    with open(output_path, "w") as f:
        f.write("\n".join(rows) + "\n")
94+
95+
96+
def create_file_output(output_path: str, database: str):
8097
"""Creates a file output for the database.
8198
8299
Args:
83100
output_path: Path to the output file.
84101
db: Database object.
85102
"""
103+
104+
db = get_or_create_database(database)
105+
86106
# Check if output path exists
87107
if not os.path.exists(output_path):
88108
os.makedirs(output_path)
@@ -114,6 +134,52 @@ def create_file_output(output_path: str, db: sqlite_usearch_impl.SQLiteUsearchDB
114134
f.write(",".join(map(str, embedding.tolist())))
115135

116136

137+
def consume_embedding(fpath, s_start, s_end, embeddings, db: sqlite_usearch_impl.SQLiteUsearchDB):
    """Insert one embedding into the database unless an identical source entry exists.

    Returns:
        True if the embedding was inserted, False if it was already present.
    """
    offsets = np.array([s_start, s_end])

    # Skip duplicates: same dataset, file and chunk offsets.
    existing = db.get_embeddings_by_source(DATASET_NAME, fpath, offsets)
    if existing.size != 0:
        return False

    source = hoplite.EmbeddingSource(DATASET_NAME, fpath, offsets)
    db.insert_embedding(embeddings, source)
    return True
151+
152+
153+
def consumer(q: mp.Queue, stop_at, database: str):
    """Consumer-process entry point: drain result batches and write them to the DB.

    Args:
        q: Queue of result batches produced by the worker pool; each item is
            a list of ``(fpath, s_start, s_end, embeddings)`` tuples.
        stop_at: Sentinel source path; a tuple whose fpath equals this value
            signals end of input.
        database: Path to the database.
    """
    pending = 0
    running = True
    db = get_or_create_database(database)

    check_database_settings(db)

    while running:
        # FIX: blocking get() instead of the q.empty()/sleep(0.1) poll loop —
        # no busy-waiting and no added latency. The producer always enqueues
        # the sentinel batch last, so this cannot block forever.
        results = q.get()

        for fpath, s_start, s_end, embeddings in results:
            if fpath == stop_at:
                running = False
                break

            if consume_embedding(fpath, s_start, s_end, embeddings, db):
                pending += 1

            # Commit in batches to amortize transaction overhead.
            if pending >= COMMIT_BS_SIZE:
                db.commit()
                pending = 0

    # Flush whatever remains of the final partial batch.
    db.commit()
    db.db.close()
181+
182+
117183
def run(audio_input, database, overlap, audio_speed, fmin, fmax, threads, batchsize, file_output):
118184
### Make sure to comment out appropriately if you are not using args. ###
119185

@@ -144,8 +210,6 @@ def run(audio_input, database, overlap, audio_speed, fmin, fmax, threads, batchs
144210
cfg.CPU_THREADS = 1
145211
cfg.TFLITE_THREADS = max(1, int(threads))
146212

147-
cfg.CPU_THREADS = 1 # TODO: with the current implementation, we can't use more than 1 thread
148-
149213
# Set batch size
150214
cfg.BATCH_SIZE = max(1, int(batchsize))
151215

@@ -155,18 +219,41 @@ def run(audio_input, database, overlap, audio_speed, fmin, fmax, threads, batchs
155219
# have its own config. USE LINUX!
156220
flist = [(f, cfg.get_config()) for f in cfg.FILE_LIST]
157221

158-
db = get_database(database)
159-
check_database_settings(db)
160-
161-
# Analyze files
162222
if cfg.CPU_THREADS < 2:
163-
for entry in tqdm(flist):
164-
analyze_file(entry, db)
223+
# Force single core
224+
batchsize = COMMIT_BS_SIZE
225+
batch = 0
226+
db = get_or_create_database(database)
227+
check_database_settings(db)
228+
229+
for fpath, config in tqdm(flist, desc="Files processed"):
230+
for _, s_start, s_end, embeddings in analyze_file_core(fpath, config):
231+
if consume_embedding(fpath, s_start, s_end, embeddings, db):
232+
batch += 1
233+
234+
if batch >= batchsize:
235+
db.commit()
236+
batch = 0
237+
238+
db.commit()
239+
db.db.close()
165240
else:
166-
with Pool(cfg.CPU_THREADS) as p:
167-
tqdm(p.imap(partial(analyze_file, db=db), flist))
241+
chunksize = 2
242+
queue = mp.Queue(maxsize=10_000)
243+
consumer_process = mp.Process(target=consumer, args=(queue, "STOP", database))
244+
consumer_process.start()
245+
246+
# One less process for the pool, because we use one extra for the consumer
247+
with mp.Pool(processes=cfg.CPU_THREADS - 1) as pool:
248+
delta = chunksize
249+
with tqdm(total=len(flist), desc="Files processed") as pbar:
250+
# Instead of chunk_size arg, manual splitting, because this reduces the overhead for the iterable.
251+
for res in pool.imap_unordered(analyze_file, [flist[i : i + delta] for i in range(0, len(flist), delta)], chunksize=1):
252+
queue.put(res)
253+
pbar.update(len(res))
254+
255+
queue.put([("STOP", 0, 0, None)])
256+
consumer_process.join()
168257

169258
if file_output:
170-
create_file_output(file_output, db)
171-
172-
db.db.close()
259+
create_csv_output(file_output, database)

0 commit comments

Comments
 (0)