11"""Module used to extract embeddings for samples."""
22
3- import datetime
3+ import multiprocessing as mp
44import os
5- from functools import partial
6- from multiprocessing import Pool
5+ import time
76
87import numpy as np
98from ml_collections import ConfigDict
1413import birdnet_analyzer .config as cfg
1514from birdnet_analyzer import utils
1615from birdnet_analyzer .analyze .utils import iterate_audio_chunks
17- from birdnet_analyzer .embeddings .core import get_database
16+ from birdnet_analyzer .embeddings .core import get_or_create_database
1817
1918DATASET_NAME : str = "birdnet_analyzer_dataset"
19+ COMMIT_BS_SIZE = 512
2020
2121
22- def analyze_file (item , db : sqlite_usearch_impl .SQLiteUsearchDB ):
23- """Extracts the embeddings for a file.
24-
25- Args:
26- item: (filepath, config)
27- """
28-
29- # Get file path and restore cfg
30- fpath : str = item [0 ]
31- cfg .set_config (item [1 ])
32-
33- # Start time
34- start_time = datetime .datetime .now ()
35-
36- # Status
37- print (f"Analyzing { fpath } " , flush = True )
38-
39- source_id = fpath
22+ def analyze_file_core (fpath , config ):
23+ results = []
24+ cfg .set_config (config )
4025
4126 # Process each chunk
4227 try :
4328 for s_start , s_end , embeddings in iterate_audio_chunks (fpath , embeddings = True ):
44- # Check if embedding already exists
45- existing_embedding = db .get_embeddings_by_source (DATASET_NAME , source_id , np .array ([s_start , s_end ]))
46-
47- if existing_embedding .size == 0 :
48- # Store embeddings
49- embeddings_source = hoplite .EmbeddingSource (DATASET_NAME , source_id , np .array ([s_start , s_end ]))
50-
51- # Insert into database
52- db .insert_embedding (embeddings , embeddings_source )
53- db .commit ()
29+ results .append ((fpath , s_start , s_end , embeddings ))
5430 except Exception as ex :
5531 # Write error log
5632 print (f"Error: Cannot analyze audio file { fpath } ." , flush = True )
5733 utils .write_error_log (ex )
5834
59- return
35+ return results
36+
37+
38+ def analyze_file (items ):
39+ """Extracts the embeddings for a file.
40+
41+ Args:
42+ item: (filepath, config)
43+ """
44+ results = []
45+
46+ for fpath , config in items :
47+ results .extend (analyze_file_core (fpath , config ))
6048
61- delta_time = (datetime .datetime .now () - start_time ).total_seconds ()
62- print (f"Finished { fpath } in { delta_time :.2f} seconds" , flush = True )
49+ return results
6350
6451
6552def check_database_settings (db : sqlite_usearch_impl .SQLiteUsearchDB ):
@@ -76,13 +63,46 @@ def check_database_settings(db: sqlite_usearch_impl.SQLiteUsearchDB):
7663 db .commit ()
7764
7865
79- def create_file_output (output_path : str , db : sqlite_usearch_impl .SQLiteUsearchDB ):
66+ def create_csv_output (output_path : str , database : str ):
67+ """Creates a CSV output for the database.
68+
69+ Args:
70+ output_path: Path to the output file.
71+ db: Database object.
72+ """
73+
74+ db = get_or_create_database (database )
75+ parent_dir = os .path .dirname (output_path )
76+
77+ if not os .path .exists (parent_dir ):
78+ os .makedirs (parent_dir )
79+
80+ embedding_ids = db .get_embedding_ids ()
81+
82+ csv_content = "file_path,start,end,embedding\n "
83+
84+ for embedding_id in embedding_ids :
85+ embedding = db .get_embedding (embedding_id )
86+ source = db .get_embedding_source (embedding_id )
87+
88+ start , end = source .offsets
89+
90+ csv_content += f'{ source .source_id } ,{ start } ,{ end } ,"{ "," .join (map (str , embedding .tolist ()))} "\n '
91+
92+ with open (output_path , "w" ) as f :
93+ f .write (csv_content )
94+
95+
96+ def create_file_output (output_path : str , database : str ):
8097 """Creates a file output for the database.
8198
8299 Args:
83100 output_path: Path to the output file.
84101 db: Database object.
85102 """
103+
104+ db = get_or_create_database (database )
105+
86106 # Check if output path exists
87107 if not os .path .exists (output_path ):
88108 os .makedirs (output_path )
@@ -114,6 +134,52 @@ def create_file_output(output_path: str, db: sqlite_usearch_impl.SQLiteUsearchDB
114134 f .write ("," .join (map (str , embedding .tolist ())))
115135
116136
137+ def consume_embedding (fpath , s_start , s_end , embeddings , db : sqlite_usearch_impl .SQLiteUsearchDB ):
138+ # Check if embedding already exists
139+ existing_embedding = db .get_embeddings_by_source (DATASET_NAME , fpath , np .array ([s_start , s_end ]))
140+
141+ if existing_embedding .size == 0 :
142+ # Store embeddings
143+ embeddings_source = hoplite .EmbeddingSource (DATASET_NAME , fpath , np .array ([s_start , s_end ]))
144+
145+ # Insert into database
146+ db .insert_embedding (embeddings , embeddings_source )
147+
148+ return True
149+
150+ return False
151+
152+
153+ def consumer (q : mp .Queue , stop_at , database : str ):
154+ batchsize = COMMIT_BS_SIZE
155+ batch = 0
156+ break_signal = True
157+ db = get_or_create_database (database )
158+
159+ check_database_settings (db )
160+
161+ while break_signal :
162+ if not q .empty ():
163+ results = q .get ()
164+
165+ for fpath , s_start , s_end , embeddings in results :
166+ if fpath == stop_at :
167+ break_signal = False
168+ break
169+
170+ if consume_embedding (fpath , s_start , s_end , embeddings , db ):
171+ batch += 1
172+
173+ if batch >= batchsize :
174+ db .commit ()
175+ batch = 0
176+ else :
177+ time .sleep (0.1 )
178+
179+ db .commit ()
180+ db .db .close ()
181+
182+
117183def run (audio_input , database , overlap , audio_speed , fmin , fmax , threads , batchsize , file_output ):
118184 ### Make sure to comment out appropriately if you are not using args. ###
119185
@@ -144,8 +210,6 @@ def run(audio_input, database, overlap, audio_speed, fmin, fmax, threads, batchs
144210 cfg .CPU_THREADS = 1
145211 cfg .TFLITE_THREADS = max (1 , int (threads ))
146212
147- cfg .CPU_THREADS = 1 # TODO: with the current implementation, we can't use more than 1 thread
148-
149213 # Set batch size
150214 cfg .BATCH_SIZE = max (1 , int (batchsize ))
151215
@@ -155,18 +219,41 @@ def run(audio_input, database, overlap, audio_speed, fmin, fmax, threads, batchs
155219 # have its own config. USE LINUX!
156220 flist = [(f , cfg .get_config ()) for f in cfg .FILE_LIST ]
157221
158- db = get_database (database )
159- check_database_settings (db )
160-
161- # Analyze files
162222 if cfg .CPU_THREADS < 2 :
163- for entry in tqdm (flist ):
164- analyze_file (entry , db )
223+ # Force single core
224+ batchsize = COMMIT_BS_SIZE
225+ batch = 0
226+ db = get_or_create_database (database )
227+ check_database_settings (db )
228+
229+ for fpath , config in tqdm (flist , desc = "Files processed" ):
230+ for _ , s_start , s_end , embeddings in analyze_file_core (fpath , config ):
231+ if consume_embedding (fpath , s_start , s_end , embeddings , db ):
232+ batch += 1
233+
234+ if batch >= batchsize :
235+ db .commit ()
236+ batch = 0
237+
238+ db .commit ()
239+ db .db .close ()
165240 else :
166- with Pool (cfg .CPU_THREADS ) as p :
167- tqdm (p .imap (partial (analyze_file , db = db ), flist ))
241+ chunksize = 2
242+ queue = mp .Queue (maxsize = 10_000 )
243+ consumer_process = mp .Process (target = consumer , args = (queue , "STOP" , database ))
244+ consumer_process .start ()
245+
246+ # One less process for the pool, because we use one extra for the consumer
247+ with mp .Pool (processes = cfg .CPU_THREADS - 1 ) as pool :
248+ delta = chunksize
249+ with tqdm (total = len (flist ), desc = "Files processed" ) as pbar :
250+ # Instead of chunk_size arg, manual splitting, because this reduces the overhead for the iterable.
251+ for res in pool .imap_unordered (analyze_file , [flist [i : i + delta ] for i in range (0 , len (flist ), delta )], chunksize = 1 ):
252+ queue .put (res )
253+ pbar .update (len (res ))
254+
255+ queue .put ([("STOP" , 0 , 0 , None )])
256+ consumer_process .join ()
168257
169258 if file_output :
170- create_file_output (file_output , db )
171-
172- db .db .close ()
259+ create_csv_output (file_output , database )