-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtest_inference.py
More file actions
209 lines (173 loc) · 6.57 KB
/
test_inference.py
File metadata and controls
209 lines (173 loc) · 6.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
#!/usr/bin/env python3
"""
Test script for inference functionality.
Validates that all inference components work correctly.
"""
import argparse
import json
import logging
import os
import re
import shlex
import subprocess
import sys
from pathlib import Path
# Process-wide logging: INFO and above, each record prefixed with a
# timestamp and its level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Logger used by every check in this script.
logger = logging.getLogger(__name__)
def run_command(cmd, description, capture_output=True):
"""Run a command and return success status."""
logger.info(f"Testing: {description}")
try:
if capture_output:
result = subprocess.run(cmd, shell=True, check=True, capture_output=True, text=True)
return True, result.stdout
else:
result = subprocess.run(cmd, shell=True, check=True)
return True, ""
except subprocess.CalledProcessError as e:
logger.error(f"Failed: {description}")
if capture_output and e.stderr:
logger.error(f"Error: {e.stderr}")
return False, ""
def test_config_file():
"""Test if inference config exists and is valid."""
config_path = "inference_config.json"
if not os.path.exists(config_path):
logger.error(f"Config file not found: {config_path}")
return False
try:
with open(config_path, 'r') as f:
config = json.load(f)
required_sections = ['api_models', 'open_source', 'datasets', 'output_dir']
for section in required_sections:
if section not in config:
logger.error(f"Missing config section: {section}")
return False
logger.info("✅ Config file is valid")
return True
except json.JSONDecodeError as e:
logger.error(f"Invalid JSON in config: {e}")
return False
def test_inference_scripts():
"""Test that all inference scripts exist and have basic syntax."""
inference_dir = Path("inference")
if not inference_dir.exists():
logger.error("Inference directory not found")
return False
# Find all Python files
python_files = list(inference_dir.rglob("*.py"))
if not python_files:
logger.error("No inference scripts found")
return False
logger.info(f"Found {len(python_files)} inference scripts")
# Test syntax of each script
failed = 0
for script in python_files:
cmd = f"python -m py_compile {script}"
success, _ = run_command(cmd, f"Syntax check: {script.name}")
if not success:
failed += 1
if failed == 0:
logger.info("✅ All inference scripts have valid syntax")
return True
else:
logger.error(f"❌ {failed} scripts have syntax errors")
return False
def test_run_inference_script():
"""Test the main run_inference.py script."""
if not os.path.exists("run_inference.py"):
logger.error("run_inference.py not found")
return False
# Test help output
success, output = run_command("python run_inference.py --help", "run_inference.py help")
if not success:
return False
# Check that help contains expected options
required_args = ['--task', '--model_type', '--inference_type']
for arg in required_args:
if arg not in output:
logger.error(f"Missing argument in help: {arg}")
return False
logger.info("✅ run_inference.py script is functional")
return True
def test_sample_datasets():
"""Test if sample datasets exist for inference."""
datasets = [
"three-vars/icon_only/metadata.csv",
"three-vars/char_only/metadata.csv",
"three-vars/icon_partial/metadata.csv"
]
missing = []
for dataset in datasets:
if not os.path.exists(dataset):
missing.append(dataset)
if missing:
logger.warning(f"Missing sample datasets: {missing}")
logger.warning("Run data generation first: make run-small")
return False
else:
logger.info("✅ Sample datasets available")
return True
def test_dry_run_inference():
"""Test inference with dry run mode."""
if not test_sample_datasets():
logger.info("Skipping dry run test - no sample data")
return True
# Test dry run with different configurations
test_configs = [
{
"task": "visual_equation_solving",
"dataset": "icon_only",
"model_type": "api",
"api_model": "openai",
"inference_type": "direct",
"api_key": "dummy-key"
}
]
for config in test_configs:
cmd_parts = ["python run_inference.py", "--dry_run"]
for key, value in config.items():
cmd_parts.append(f"--{key} {value}")
cmd = " ".join(cmd_parts)
success, output = run_command(cmd, f"Dry run: {config['task']}/{config.get('dataset', '')}")
if not success:
return False
if "Would execute:" not in output:
logger.error("Dry run did not produce expected output")
return False
logger.info("✅ Dry run inference tests passed")
return True
def main():
parser = argparse.ArgumentParser(description="Test inference functionality")
parser.add_argument('--quick', action='store_true', help='Run only quick tests')
args = parser.parse_args()
logger.info("🧪 Testing Math-Eval Inference System")
logger.info("=" * 50)
tests = [
("Config File", test_config_file),
("Inference Scripts", test_inference_scripts),
("Run Inference Script", test_run_inference_script),
("Sample Datasets", test_sample_datasets),
]
if not args.quick:
tests.append(("Dry Run Inference", test_dry_run_inference))
passed = 0
total = len(tests)
for test_name, test_func in tests:
logger.info(f"\n📋 Testing: {test_name}")
if test_func():
passed += 1
else:
logger.error(f"❌ Test failed: {test_name}")
logger.info(f"\n📊 Test Results: {passed}/{total} passed")
if passed == total:
logger.info("🎉 All inference tests passed!")
logger.info("\nNext steps:")
logger.info("1. Configure API keys in inference_config.json")
logger.info("2. Generate datasets: make run-small")
logger.info("3. Run inference: python run_inference.py --help")
return True
else:
logger.error("❌ Some tests failed. Please fix issues before proceeding.")
return False
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)