diff --git a/src/muse/evaluation/metrics.py b/src/muse/evaluation/metrics.py index 0d7f773..2c07ef7 100644 --- a/src/muse/evaluation/metrics.py +++ b/src/muse/evaluation/metrics.py @@ -120,7 +120,9 @@ def compute_cometkiwi( try: model_path = download_model("Unbabel/wmt22-cometkiwi-da") LOADED_METRICS["cometkiwi"] = load_from_checkpoint(model_path) - except KeyError as e: # download_model catches all exceptions and re-raises as KeyError + except ( + KeyError + ) as e: # download_model catches all exceptions and re-raises as KeyError msg = ( "Authentication required for CometKiwi model. " "Please:\n" @@ -139,6 +141,6 @@ def compute_cometkiwi( # Predict returns a Prediction object; access the first score model_output = model.predict(data, batch_size=1, gpus=gpus) # The Prediction object can be indexed to get individual scores - score = model_output[0] + score = model_output[0][0] return score