Skip to content

Commit 87b220a

Browse files
authored
disable tf warnings (#769)
1 parent cf5925e commit 87b220a

1 file changed

Lines changed: 20 additions & 24 deletions

File tree

birdnet_analyzer/model.py

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
11
# ruff: noqa: PLW0603
22
"""Contains functions to use the BirdNET models."""
33

4+
import logging
45
import os
56
import sys
67
import warnings
78

9+
import absl.logging
810
import numpy as np
911

1012
import birdnet_analyzer.config as cfg
1113
from birdnet_analyzer import utils
1214

13-
SCRIPT_DIR = os.path.abspath(os.path.dirname(__file__))
14-
15-
15+
absl.logging.set_verbosity(absl.logging.ERROR)
16+
logging.getLogger("tensorflow").setLevel(logging.ERROR)
1617
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
1718
os.environ["CUDA_VISIBLE_DEVICES"] = ""
18-
1919
warnings.filterwarnings("ignore")
2020

2121
# Import TFLite from runtime or Tensorflow;
@@ -29,6 +29,7 @@
2929
if not cfg.MODEL_PATH.endswith(".tflite"):
3030
from tensorflow import keras
3131

32+
SCRIPT_DIR = os.path.abspath(os.path.dirname(__file__))
3233
INTERPRETER: tflite.Interpreter = None
3334
C_INTERPRETER: tflite.Interpreter = None
3435
M_INTERPRETER: tflite.Interpreter = None
@@ -38,6 +39,15 @@
3839
EMPTY_CLASS_EXCEPTION_REF = None
3940

4041

42+
def _load_interpreter(mpath, threads):
43+
return tflite.Interpreter(
44+
model_path=mpath,
45+
num_threads=threads,
46+
# XNNPACK disabled, because it does not support variable inputsize anyway (ie batchsize)
47+
experimental_op_resolver_type=tflite.experimental.OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES,
48+
)
49+
50+
4151
def get_empty_class_exception():
4252
import keras_tuner.errors
4353

@@ -361,11 +371,7 @@ def upsampling(x: np.ndarray, y: np.ndarray, ratio=0.5, mode="repeat"):
361371
rng = np.random.default_rng(cfg.RANDOM_SEED)
362372

363373
# Determine min number of samples
364-
min_samples = (
365-
int(max(y.sum(axis=0), len(y) - y.sum(axis=0)) * ratio)
366-
if cfg.BINARY_CLASSIFICATION
367-
else int(np.max(y.sum(axis=0)) * ratio)
368-
)
374+
min_samples = int(max(y.sum(axis=0), len(y) - y.sum(axis=0)) * ratio) if cfg.BINARY_CLASSIFICATION else int(np.max(y.sum(axis=0)) * ratio)
369375

370376
x_temp = []
371377
y_temp = []
@@ -516,9 +522,7 @@ def load_model(class_output=True):
516522
if cfg.MODEL_PATH.endswith(".tflite"):
517523
if not INTERPRETER:
518524
# Load TFLite model and allocate tensors.
519-
INTERPRETER = tflite.Interpreter(
520-
model_path=os.path.join(SCRIPT_DIR, cfg.MODEL_PATH), num_threads=cfg.TFLITE_THREADS
521-
)
525+
INTERPRETER = _load_interpreter(os.path.join(SCRIPT_DIR, cfg.MODEL_PATH), cfg.TFLITE_THREADS)
522526
INTERPRETER.allocate_tensors()
523527

524528
# Get input and output tensors.
@@ -553,7 +557,7 @@ def load_custom_classifier():
553557

554558
if cfg.CUSTOM_CLASSIFIER.endswith(".tflite"):
555559
# Load TFLite model and allocate tensors.
556-
C_INTERPRETER = tflite.Interpreter(model_path=cfg.CUSTOM_CLASSIFIER, num_threads=cfg.TFLITE_THREADS)
560+
C_INTERPRETER = _load_interpreter(cfg.CUSTOM_CLASSIFIER, cfg.TFLITE_THREADS)
557561
C_INTERPRETER.allocate_tensors()
558562

559563
# Get input and output tensors.
@@ -585,9 +589,7 @@ def load_meta_model():
585589
global M_OUTPUT_LAYER_INDEX
586590

587591
# Load TFLite model and allocate tensors.
588-
M_INTERPRETER = tflite.Interpreter(
589-
model_path=os.path.join(SCRIPT_DIR, cfg.MDATA_MODEL_PATH), num_threads=cfg.TFLITE_THREADS
590-
)
592+
M_INTERPRETER = _load_interpreter(os.path.join(SCRIPT_DIR, cfg.MDATA_MODEL_PATH), cfg.TFLITE_THREADS)
591593
M_INTERPRETER.allocate_tensors()
592594

593595
# Get input and output tensors.
@@ -633,11 +635,7 @@ def build_linear_classifier(num_labels, input_size, hidden_units=0, dropout=0.0)
633635
model.add(keras.layers.Dropout(dropout))
634636

635637
# Add a hidden layer with L2 regularization
636-
model.add(
637-
keras.layers.Dense(
638-
hidden_units, activation="relu", kernel_regularizer=regularizer, kernel_initializer="he_normal"
639-
)
640-
)
638+
model.add(keras.layers.Dense(hidden_units, activation="relu", kernel_regularizer=regularizer, kernel_initializer="he_normal"))
641639

642640
# Add another batch normalization after the hidden layer
643641
model.add(keras.layers.BatchNormalization())
@@ -813,9 +811,7 @@ def _focal_loss(y_true, y_pred):
813811
)
814812

815813
# Train model
816-
history = classifier.fit(
817-
x_train, y_train, epochs=epochs, batch_size=batch_size, validation_data=(x_val, y_val), callbacks=callbacks
818-
)
814+
history = classifier.fit(x_train, y_train, epochs=epochs, batch_size=batch_size, validation_data=(x_val, y_val), callbacks=callbacks)
819815

820816
return classifier, history
821817

0 commit comments

Comments
 (0)