-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathrun_inference.py
More file actions
189 lines (160 loc) · 7.54 KB
/
run_inference.py
File metadata and controls
189 lines (160 loc) · 7.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
#!/usr/bin/env python3
"""
Unified inference runner for Math-Eval dataset.
Supports both API-based models (OpenAI, Gemini) and open-source models.
"""
import argparse
import json
import logging
import os
import subprocess
import sys
from pathlib import Path
# Configure root logging once at import time: INFO level with timestamped records.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger named after this module, per logging convention.
logger = logging.getLogger(__name__)
def load_config(config_path):
    """Load and parse a JSON configuration file.

    Args:
        config_path: Path to the JSON configuration file.

    Returns:
        dict: The parsed configuration.

    Exits the process with status 1 (after logging an error) when the
    file is missing or contains invalid JSON, so callers never see an
    exception from this function.
    """
    try:
        # Explicit encoding avoids platform-dependent default codecs
        # (e.g. cp1252 on Windows) mangling the JSON text.
        with open(config_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        logger.error(f"Configuration file not found: {config_path}")
        sys.exit(1)
    except json.JSONDecodeError as e:
        logger.error(f"Invalid JSON in configuration file: {e}")
        sys.exit(1)
def validate_arguments(args, config):
    """Sanity-check parsed CLI arguments against the loaded configuration.

    Logs a descriptive error and returns False as soon as one check
    fails; returns True when the argument combination is runnable.
    """
    valid_tasks = ['counting', 'visual_equation_solving']
    valid_datasets = ['char_only', 'icon_only', 'icon_partial']

    if args.task not in valid_tasks:
        logger.error(f"Invalid task: {args.task}. Must be one of {valid_tasks}")
        return False

    # The equation-solving task is split into dataset variants; one of
    # the known variants must be selected for it.
    needs_dataset = args.task == 'visual_equation_solving'
    if needs_dataset and args.dataset not in valid_datasets:
        logger.error(f"Invalid dataset: {args.dataset}. Must be one of {valid_datasets}")
        return False

    if args.model_type == 'api':
        # API runs need both a configured provider entry and a key.
        if args.api_model not in config['api_models']:
            logger.error(f"API model {args.api_model} not configured")
            return False
        if not args.api_key:
            logger.error("API key required for API models")
            return False
    elif args.model_type == 'open_source':
        # Open-source runs only need a configured model entry.
        if args.os_model not in config['open_source']['models']:
            logger.error(f"Open source model {args.os_model} not configured")
            return False

    return True
def get_inference_script_path(task, dataset, model_type, inference_type):
    """Resolve the concrete inference script for a task/dataset/model combo.

    Scripts live under ./inference relative to this file, organized by
    task, dataset variant, and model family ("api-models" for API
    providers vs "open-source" for local models).
    """
    root = Path(__file__).parent / "inference"
    family = "api-models" if model_type == 'api' else "open-source"

    if task == 'counting':
        return root / "counting" / family / f"inference_{inference_type}.py"

    # visual_equation_solving: the char_only dataset uses the "ocr_only"
    # script stem, both icon datasets share the "icon" stem, and the
    # 'direct' variant carries no inference-type suffix in its filename.
    stem = "ocr_only" if dataset == 'char_only' else "icon"
    if inference_type == 'direct':
        filename = f"inference_{stem}.py"
    else:
        filename = f"inference_{stem}_{inference_type}.py"
    return root / "visual_equation_solving" / dataset / family / filename
def build_command(args, config, script_path):
    """Assemble the argv list used to launch the selected inference script.

    The command always carries the metadata file, output CSV path, model
    identifier, and image directory; API runs additionally pass the key.
    """
    # Dataset config and output filename both depend on the task.
    if args.task == 'counting':
        ds_cfg = config['datasets']['counting']
        out_name = f"counting_{args.model_type}_{args.inference_type}_results.csv"
    else:
        ds_cfg = config['datasets']['visual_equation_solving'][args.dataset]
        out_name = f"{args.task}_{args.dataset}_{args.model_type}_{args.inference_type}_results.csv"

    command = [
        sys.executable,
        str(script_path),
        '--metadata', ds_cfg['metadata_file'],
        '--output', os.path.join(config['output_dir'], out_name),
    ]

    if args.model_type == 'api':
        # NOTE(review): the API key travels on the command line, which is
        # visible in the process table — presumably acceptable here.
        command += ['--model', args.api_model, '--api_key', args.api_key]
    else:
        command += ['--model', args.os_model]

    # An explicit --image_dir overrides the dataset's configured default.
    image_dir = args.image_dir if args.image_dir else ds_cfg['image_dir']
    command += ['--image_path', image_dir]
    return command
def main():
    """Parse CLI arguments, resolve the inference script, and run it.

    Exits with status 1 on configuration/validation errors or when the
    child inference process fails. With --dry_run, only logs the command
    that would be executed.
    """
    parser = argparse.ArgumentParser(description="Run inference on Math-Eval dataset")
    parser.add_argument('--config', type=str, default='inference_config.json',
                        help='Path to configuration file')
    parser.add_argument('--task', type=str, required=True,
                        choices=['counting', 'visual_equation_solving'],
                        help='Type of task to run inference on')
    parser.add_argument('--dataset', type=str,
                        choices=['char_only', 'icon_only', 'icon_partial'],
                        help='Dataset type (required for visual_equation_solving)')
    parser.add_argument('--model_type', type=str, required=True,
                        choices=['api', 'open_source'],
                        help='Type of model to use')
    parser.add_argument('--api_model', type=str,
                        choices=['openai', 'gemini'],
                        help='API model to use (required if model_type=api)')
    parser.add_argument('--os_model', type=str,
                        choices=['llama_vision', 'molmo', 'qwen2_vl'],
                        help='Open source model to use (required if model_type=open_source)')
    parser.add_argument('--api_key', type=str,
                        help='API key (required if model_type=api)')
    parser.add_argument('--inference_type', type=str, required=True,
                        choices=['direct', 'two_step', 'cot'],
                        help='Type of inference to run')
    parser.add_argument('--image_dir', type=str,
                        help='Custom image directory (overrides config)')
    parser.add_argument('--dry_run', action='store_true',
                        help='Print command without executing')
    args = parser.parse_args()

    # Load configuration (exits on missing/invalid file).
    config = load_config(args.config)

    # Cross-field validation that argparse choices alone cannot express.
    if not validate_arguments(args, config):
        sys.exit(1)

    # Ensure the results directory exists before the child tries to write.
    os.makedirs(config['output_dir'], exist_ok=True)

    # Resolve and check the concrete inference script.
    script_path = get_inference_script_path(args.task, args.dataset, args.model_type, args.inference_type)
    if not script_path.exists():
        logger.error(f"Inference script not found: {script_path}")
        sys.exit(1)

    cmd = build_command(args, config, script_path)
    if args.dry_run:
        logger.info(f"Would execute: {' '.join(cmd)}")
        return

    # Execute the inference script as a child process. Passing argv as a
    # list (shell=False) avoids shell-injection via user-supplied paths
    # or keys; output is captured so failures can be logged verbatim.
    logger.info(f"Running inference: {' '.join(cmd)}")
    try:
        result = subprocess.run(cmd, check=True, capture_output=True, text=True)
        logger.info("Inference completed successfully")
        if result.stdout:
            print(result.stdout)
    except subprocess.CalledProcessError as e:
        logger.error(f"Inference failed with exit code {e.returncode}")
        if e.stderr:
            logger.error(f"Error output: {e.stderr}")
        sys.exit(1)
if __name__ == "__main__":
main()