Commit 35ddc1a

update
1 parent 57903b4 commit 35ddc1a

21 files changed

Lines changed: 1600 additions & 173 deletions

editscore/__init__.py

Lines changed: 65 additions & 4 deletions
@@ -8,6 +8,7 @@
 import math
 from . import vie_prompts
 import numpy as np
+from .json_parser import parse_vlm_output_to_dict
 
 class EditScore:
     def __init__(
@@ -91,16 +92,16 @@ def evaluate(self, image_prompts, text_prompt):
             max_tries = 2
             while SC_dict is False or PQ_dict is False:
                 tries += 1
-                guess_if_cannot_parse = True if tries > max_tries else False
+                give_up_parsing = True if tries > max_tries else False
 
                 result_SC = self.model.inference(SC_prompt_final, seed=self.seed + i)
                 result_PQ = self.model.inference(PQ_prompt_final, seed=self.seed + i)
 
                 if result_SC in ["I'm sorry, but I can't assist with that request."] or result_PQ in ["I'm sorry, but I can't assist with that request."]:
-                    guess_if_cannot_parse = True
+                    give_up_parsing = True
 
-                SC_dict = mllm_output_to_dict(result_SC, give_up_parsing=guess_if_cannot_parse, text_prompt=text_prompt, score_range=self.score_range)
-                PQ_dict = mllm_output_to_dict(result_PQ, give_up_parsing=guess_if_cannot_parse, text_prompt=text_prompt, score_range=self.score_range)
+                SC_dict = mllm_output_to_dict(result_SC, give_up_parsing=give_up_parsing, text_prompt=text_prompt, score_range=self.score_range)
+                PQ_dict = mllm_output_to_dict(result_PQ, give_up_parsing=give_up_parsing, text_prompt=text_prompt, score_range=self.score_range)
 
                 if SC_dict == "rate_limit_exceeded" or PQ_dict == "rate_limit_exceeded":
                     print("rate_limit_exceeded")
@@ -136,3 +137,63 @@ def evaluate(self, image_prompts, text_prompt):
         if self.reduction == "average_first":
             output["overall"] = math.sqrt(output["prompt_following"] * output["perceptual_quality"])
         return output
+
+
+    def batch_evaluate(self, image_prompts, text_prompt):
+        SC_prompt = [self.SC_prompt.replace("<instruction>", _text_prompt) for _text_prompt in text_prompt]
+
+        SC_prompt = [self.model.prepare_input(image_prompt, _SC_prompt) for image_prompt, _SC_prompt in zip(image_prompts, SC_prompt)]
+        PQ_prompt = [self.model.prepare_input(image_prompt, self.PQ_prompt) for image_prompt in image_prompts]
+
+        outputs_multi_pass = [[] for _ in range(len(image_prompts))]
+        for i in range(self.num_pass):
+            results = self.model.batch_inference(SC_prompt + PQ_prompt, seed=self.seed + i)
+
+            SC_evaluations = [parse_vlm_output_to_dict(results[i]) for i in range(len(results) // 2)]
+            PQ_evaluations = [parse_vlm_output_to_dict(results[i]) for i in range(len(results) // 2, len(results))]
+
+            for idx, (SC_evaluation, PQ_evaluation) in enumerate(zip(SC_evaluations, PQ_evaluations)):
+                SC_scores = SC_evaluation["score"]
+                PQ_scores = PQ_evaluation["score"]
+
+                if len(SC_scores) == 0:
+                    SC_scores = [self.score_range / 2]
+                if len(PQ_scores) == 0:
+                    PQ_scores = [self.score_range / 2]
+
+                SC_score = min(SC_scores) / (self.score_range / 10)
+                PQ_score = min(PQ_scores) / (self.score_range / 10)
+                if SC_score < 0 or SC_score > 10:
+                    SC_score = self.score_range / 2
+                if PQ_score < 0 or PQ_score > 10:
+                    PQ_score = self.score_range / 2
+                O_score = math.sqrt(SC_score * PQ_score)
+
+                outputs_multi_pass[idx].append(
+                    {
+                        "SC_score": SC_score,
+                        "PQ_score": PQ_score,
+                        "O_score": O_score,
+                        "SC_score_reasoning": SC_evaluation["reasoning"],
+                        "PQ_score_reasoning": PQ_evaluation["reasoning"],
+                        "SC_raw_output": results[idx],
+                        "PQ_raw_output": results[len(results) // 2 + idx],
+                    }
+                )
+
+        outputs = []
+        for idx, outputs_per_prompt in enumerate(outputs_multi_pass):
+            outputs.append(
+                {
+                    "SC_score": np.mean([output_per_pass["SC_score"] for output_per_pass in outputs_per_prompt]),
+                    "PQ_score": np.mean([output_per_pass["PQ_score"] for output_per_pass in outputs_per_prompt]),
+                    "O_score": np.mean([output_per_pass["O_score"] for output_per_pass in outputs_per_prompt]),
+                    "SC_score_reasoning": outputs_per_prompt[0]["SC_score_reasoning"],
+                    "PQ_score_reasoning": outputs_per_prompt[0]["PQ_score_reasoning"],
+                    "SC_raw_output": outputs_per_prompt[0]["SC_raw_output"],
+                    "PQ_raw_output": outputs_per_prompt[0]["PQ_raw_output"],
+                }
+            )
+            if self.reduction == "average_first":
+                outputs[-1]["O_score"] = math.sqrt(outputs[-1]["SC_score"] * outputs[-1]["PQ_score"])
+        return outputs
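
The new batch_evaluate mirrors evaluate but concatenates all SC and PQ prompts into one batch_inference call and averages per-pass scores. A minimal usage sketch, not part of the commit: the constructor arguments and the layout of each image tuple are assumptions, since only batch_evaluate itself appears in the diff.

# Hypothetical usage; EditScore's constructor arguments are assumed, not shown in this diff.
from editscore import EditScore

scorer = EditScore(backbone="qwen25vl_vllm")  # hypothetical argument name

# One (source image, edited image) pair per instruction; prepare_input defines the exact layout.
image_prompts = [("input_0.png", "edited_0.png"), ("input_1.png", "edited_1.png")]
instructions = ["make the sky purple", "remove the red car"]

results = scorer.batch_evaluate(image_prompts, instructions)
for res in results:
    # Each entry holds SC/PQ/O scores averaged over num_pass passes (see diff above).
    print(res["SC_score"], res["PQ_score"], res["O_score"])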

editscore/json_parser.py

Lines changed: 173 additions & 0 deletions
@@ -0,0 +1,173 @@
+import json
+import re
+from typing import Dict, Any, List, Optional
+
+# ==============================================================================
+# HELPER FUNCTIONS (Based on your provided robust fixers)
+# ==============================================================================
+# For clarity, these are named as internal functions (prefixed with an underscore).
+
+def _fix_json_quotes(s: str) -> str:
+    """First-stage repair: handle incorrect quotes and basic structure."""
+    # Replace Python-style booleans/None with JSON standard
+    s = re.sub(r'\bTrue\b', 'true', s)
+    s = re.sub(r'\bFalse\b', 'false', s)
+    s = re.sub(r'\bNone\b', 'null', s)
+
+    # Attempt to replace single quotes with double quotes (a common VLM error)
+    # This is a high-risk operation that might break the reasoning content,
+    # but it's worth trying early on.
+    try:
+        temp_s = s.replace("'", '"')
+        json.loads(temp_s)
+        return temp_s
+    except json.JSONDecodeError:
+        # If it's still invalid after replacement, return the original string for the next repair step.
+        pass
+
+    # Add double quotes to keys (e.g., {reasoning: ...} -> {"reasoning": ...})
+    s = re.sub(r'([\{\s,])(\w+)\s*:', r'\1"\2":', s)
+    return s
+
+def _repair_reasoning_field_robust(json_str: str) -> str:
+    """Second-stage repair: specifically fix unescaped double quotes inside the 'reasoning' field."""
+    pattern = re.compile(
+        r'("reasoning"\s*:\s*")'  # --- Group 1: "reasoning" : "
+        r'(.*?)'                  # --- Group 2: The content (non-greedy)
+        r'(?="\s*[,}])',          # --- Lookahead: find " followed by , or }
+        re.DOTALL
+    )
+
+    def replacer(match):
+        prefix = match.group(1)
+        content = match.group(2)
+        # In the content, replace all unescaped " with \"
+        fixed_content = content.replace('"', '\\"')
+        return prefix + fixed_content
+
+    return pattern.sub(replacer, json_str)
+
+def _fallback_extract_and_rebuild(input_str: str) -> str:
+    """Final fallback strategy: abandon repair, directly extract information, and rebuild a valid JSON."""
+    # 1. Extract reasoning
+    # Find all content between "reasoning": and ,"score":
+    reasoning_text = ""
+    reason_match = re.search(r'["\']reasoning["\']\s*:\s*["\']?(.*?)["\']?\s*,\s*["\']score["\']', input_str, re.DOTALL | re.IGNORECASE)
+    if reason_match:
+        reasoning_text = reason_match.group(1).strip()
+        # Clean up any potentially remaining escape characters
+        reasoning_text = reasoning_text.replace('\\"', '"')
+    else:
+        # If not found, assume all text besides the score part is the reasoning.
+        # First, remove the score part.
+        score_part_match = re.search(r'["\']score["\']\s*:.*', input_str, re.IGNORECASE)
+        if score_part_match:
+            reasoning_text = input_str[:score_part_match.start()].strip()
+        else:
+            # If even 'score' cannot be found, assume the entire string is the reasoning.
+            reasoning_text = input_str
+
+    # 2. Extract scores
+    scores = []
+    # Prioritize searching after the 'score' keyword.
+    score_match = re.search(r'["\']score["\']\s*:\s*(.*)', input_str, re.DOTALL | re.IGNORECASE)
+    search_area = score_match.group(1) if score_match else input_str
+
+    # Find all integers or floats.
+    numbers = re.findall(r'[-+]?\d*\.?\d+', search_area)
+    if numbers:
+        scores = [float(num) for num in numbers]
+
+    # 3. Rebuild into a standard dictionary and return a JSON string.
+    rebuilt_data = {
+        "reasoning": reasoning_text,
+        "score": scores
+    }
+    return json.dumps(rebuilt_data, ensure_ascii=False)
+
+
+def _format_and_validate_dict(data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+    """Validate and format the parsed dictionary to ensure it meets the final output standard."""
+    if not isinstance(data, dict):
+        return None
+
+    # Extract reasoning, tolerating case and spelling variations.
+    reasoning = ""
+    for key in ["reasoning", "reason", "rationale"]:
+        if key in data and isinstance(data[key], str):
+            reasoning = data[key]
+            break
+
+    # Extract score, and ensure it is a list of floats.
+    scores = []
+    if 'score' in data:
+        score_val = data['score']
+        if isinstance(score_val, list):
+            scores = [float(s) for s in score_val if isinstance(s, (int, float, str))]
+        elif isinstance(score_val, (int, float)):
+            scores = [float(score_val)]
+
+    # If any field was found, consider it a success.
+    if reasoning or scores:
+        return {"score": scores, "reasoning": reasoning}
+
+    return None
+
+# ==============================================================================
+# MAIN PARSING FUNCTION
+# ==============================================================================
+
+def parse_vlm_output_to_dict(input_string: str) -> Dict[str, Any]:
+    """
+    A highly robust function to parse a VLM's output string into a dictionary
+    containing 'score' and 'reasoning'.
+
+    It uses a multi-stage repair pipeline, progressively degrading from standard
+    JSON parsing to a final information extraction fallback.
+    """
+    # --- 0. Preprocessing ---
+    if not input_string or not input_string.strip():
+        return {"score": [], "reasoning": "Input was empty."}
+
+    # Find the substring enclosed by `{}`, which is often the core of the VLM output.
+    json_match = re.search(r'\{.*\}', input_string, re.DOTALL)
+    target_str = json_match.group(0) if json_match else input_string.strip()
+
+    # --- Repair Pipeline ---
+    # Apply fixers in order, attempting to parse after each one.
+
+    fixer_pipeline = [
+        lambda s: s,                     # 1. Try the original string.
+        _fix_json_quotes,                # 2. Fix basic quotes and keywords.
+        _repair_reasoning_field_robust,  # 3. Fix internal quotes in the reasoning field.
+    ]
+
+    for fixer in fixer_pipeline:
+        try:
+            fixed_str = fixer(target_str)
+            data = json.loads(fixed_str)
+            validated_data = _format_and_validate_dict(data)
+            if validated_data is not None:
+                return validated_data
+        except (json.JSONDecodeError, TypeError):
+            # If it fails, continue to the next fixer.
+            continue
+
+    # --- Final Fallback Strategy ---
+    # If all repair and parsing attempts fail, activate the information extraction mode.
+    try:
+        fallback_str = _fallback_extract_and_rebuild(target_str)
+        # This function guarantees a valid JSON string, so we can load it directly.
+        data = json.loads(fallback_str)
+        # Still run it through the validator to standardize the format.
+        validated_data = _format_and_validate_dict(data)
+        if validated_data:
+            return validated_data
+    except Exception:
+        # If even the final fallback strategy fails, return an error message.
+        pass
+
+    return {
+        "score": [],
+        "reasoning": f"Failed to parse after all strategies. Original output: '{input_string}'"
+    }
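
A small usage sketch, not part of the commit, exercising the three paths of the parser added above: clean JSON parsed directly, Python-style quoting repaired by _fix_json_quotes, and free-form text handled by the extraction fallback.

from editscore.json_parser import parse_vlm_output_to_dict

# Valid JSON parses on the first pipeline stage (identity fixer).
parse_vlm_output_to_dict('{"reasoning": "Edit applied cleanly.", "score": [8]}')
# -> {"score": [8.0], "reasoning": "Edit applied cleanly."}

# Single quotes (a common VLM error) are repaired by _fix_json_quotes.
parse_vlm_output_to_dict("{'reasoning': 'Minor artifacts.', 'score': [6, 7]}")
# -> {"score": [6.0, 7.0], "reasoning": "Minor artifacts."}

# Free-form text falls through to _fallback_extract_and_rebuild, which pulls
# out the number and keeps the raw text as the reasoning.
parse_vlm_output_to_dict("reasoning: looks fine score: 9")
# -> {"score": [9.0], "reasoning": "reasoning: looks fine score: 9"}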

editscore/mllm_tools/qwen25vl_vllm.py

Lines changed: 22 additions & 3 deletions
@@ -1,6 +1,7 @@
 from typing import Optional
 
 import os
+import hashlib
 import random
 import time
 import numpy as np
@@ -56,8 +57,13 @@ def __init__(
 
         if self.enable_lora:
             if cache_dir is None:
-                root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-                cache_dir = os.path.join(root_dir, "cache", f"{os.path.basename(vlm_model)}_merged_lora")
+                root_dir = torch.hub.get_dir()  # default: ~/.cache/torch/hub
+
+                lora_filename = os.path.splitext(os.path.basename(lora_path))[0]
+                lora_hash = hashlib.md5(lora_path.encode()).hexdigest()[:8]
+                lora_identifier = f"{lora_filename}_{lora_hash}"
+
+                cache_dir = os.path.join(root_dir, "EditScore", f"{os.path.basename(vlm_model)}_merged_lora_{lora_identifier}")
 
             if not os.path.exists(cache_dir):
                 print(f"Merging LORA to {vlm_model} and saving to {cache_dir}", flush=True)
@@ -120,4 +126,17 @@ def inference(self, messages, seed: Optional[int] = None):
             instruction = output.outputs[0].text.strip()
             responses.append(instruction)
 
-        return responses[0]
+        return responses[0]
+
+    def batch_inference(self, messages, seed: Optional[int] = None):
+        seed = self.seed if seed is None else seed
+        sampling_params = SamplingParams(max_tokens=512, temperature=self.temperature, top_p=0.9, top_k=20, seed=seed)
+        outputs = self.model.generate(messages, sampling_params, use_tqdm=False)
+
+        responses = []
+        for output in outputs:
+            instruction = output.outputs[0].text.strip()
+            responses.append(instruction)
+
+        return responses
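
The cache-path change above keys each merged-LoRA directory by the checkpoint filename plus a short hash of its full path, so two different LoRA checkpoints no longer collide on one shared "cache" directory. A standalone sketch of the naming scheme; the lora_path and vlm_model values are made up for illustration.

import hashlib
import os

lora_path = "/checkpoints/editscore_lora.safetensors"  # hypothetical checkpoint
vlm_model = "Qwen2.5-VL-7B-Instruct"                   # hypothetical model name
root_dir = os.path.expanduser("~/.cache/torch/hub")    # torch.hub.get_dir() default

lora_filename = os.path.splitext(os.path.basename(lora_path))[0]  # "editscore_lora"
lora_hash = hashlib.md5(lora_path.encode()).hexdigest()[:8]       # stable 8-char digest of the full path
cache_dir = os.path.join(root_dir, "EditScore", f"{os.path.basename(vlm_model)}_merged_lora_{lora_filename}_{lora_hash}")
print(cache_dir)
# e.g. ~/.cache/torch/hub/EditScore/Qwen2.5-VL-7B-Instruct_merged_lora_editscore_lora_<hash>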
