-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathfix_inference_scripts.py
More file actions
177 lines (141 loc) · 6.12 KB
/
fix_inference_scripts.py
File metadata and controls
177 lines (141 loc) · 6.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
#!/usr/bin/env python3
"""
Script to standardize and fix inference scripts in the math-eval repository.
This ensures all scripts have consistent CLI interfaces and error handling.
"""
import re
import argparse
from pathlib import Path
def fix_openai_api_call(content):
    """Rewrite legacy (pre-1.0) OpenAI chat-completion usage to the 1.x style.

    Two rewrites are applied:

    * ``openai.ChatCompletion.create(`` -> ``openai.chat.completions.create(``
    * ``response["choices"][0]["message"]["content"]`` -> 
      ``response.choices[0].message.content``. Keys may be double- or
      single-quoted.

    Args:
        content: Source text of a Python script.

    Returns:
        The rewritten source text, or ``content`` unchanged if nothing matched.
    """
    # \b keeps identifiers that merely end in "openai" (e.g. my_openai) intact.
    content = re.sub(
        r"\bopenai\.ChatCompletion\.create\(",
        "openai.chat.completions.create(",
        content,
    )
    # 1.x responses are objects, so dict-style subscripting becomes attribute
    # access. Each key may use either quote style, but a key's opening and
    # closing quotes must match (backreferences \1, \2, \3). \b prevents
    # matching inside longer names such as my_response.
    content = re.sub(
        r"\bresponse"
        r"\[(['\"])choices\1\]"
        r"\[0\]"
        r"\[(['\"])message\2\]"
        r"\[(['\"])content\3\]",
        "response.choices[0].message.content",
        content,
    )
    return content
def fix_gemini_model_call(content, old_model="gemini-2.0-flash",
                          new_model="gemini-1.5-pro"):
    """Rewrite ``genai.GenerativeModel(<old_model>)`` to use ``new_model``.

    Only a call with a single positional string-literal argument is matched,
    and the quotes around that literal must match each other. The
    replacement is always emitted with single quotes.

    Args:
        content: Source text of a Python script.
        old_model: Model name to look for (matched literally, not as a regex).
        new_model: Model name to substitute.

    Returns:
        The rewritten source text, or ``content`` unchanged if nothing matched.
    """
    # NOTE(review): by default this maps gemini-2.0-flash onto gemini-1.5-pro,
    # an older model generation, so it is not an upgrade to the "latest
    # version". The defaults are kept for backward compatibility; confirm the
    # intended target model with the script owners.
    pattern = (
        r"\bgenai\.GenerativeModel\((['\"])"
        + re.escape(old_model)
        + r"\1\)"
    )
    replacement = f"genai.GenerativeModel('{new_model}')"
    # A callable replacement stops re.sub from interpreting backslashes or
    # group references that may appear in new_model.
    return re.sub(pattern, lambda _match: replacement, content)
def add_error_handling(content):
    """Ensure a script imports ``logging``/``sys`` and configures a logger.

    Despite its name, this wraps no code in try/except. It only adds:

    * ``import logging`` and/or ``import sys`` when the module does not
      already bind those names through a top-level import, and
    * a ``logging.basicConfig(...)`` call plus a module-level ``logger`` when
      the text ``logging.basicConfig`` does not already appear anywhere.

    New imports go after any leading comment lines (shebang, encoding
    cookie), the module docstring and ``from __future__`` imports, so all of
    those keep working. The logging setup goes directly after the last
    top-level import.

    Args:
        content: Source text of a Python script.

    Returns:
        The updated source text. If ``content`` is not parseable Python, it is
        returned unchanged, because insertion points cannot be located
        safely. The result is also unchanged when nothing needs adding.
    """
    import ast

    logging_setup = """
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
"""
    try:
        tree = ast.parse(content)
    except SyntaxError:
        return content

    # Collect the names bound by top-level imports. "import logging.handlers"
    # binds "logging", while "import logging as log" binds only "log". Also
    # remember where the last top-level import statement ends (1-based line).
    bound_names = set()
    last_import_end = 0
    for node in tree.body:
        if isinstance(node, ast.Import):
            bound_names.update(
                alias.asname or alias.name.split(".")[0] for alias in node.names
            )
            last_import_end = node.end_lineno
        elif isinstance(node, ast.ImportFrom):
            bound_names.update(alias.asname or alias.name for alias in node.names)
            last_import_end = node.end_lineno

    missing_imports = [
        f"import {name}\n" for name in ("logging", "sys") if name not in bound_names
    ]
    needs_setup = "logging.basicConfig" not in content
    if not missing_imports and not needs_setup:
        return content

    # Number of leading lines that must stay ahead of any new import: the
    # module docstring (which must be the first statement) and __future__
    # imports (which must precede all other code).
    header_end = 0
    for index, node in enumerate(tree.body):
        is_docstring = (
            index == 0
            and isinstance(node, ast.Expr)
            and isinstance(node.value, ast.Constant)
            and isinstance(node.value.value, str)
        )
        is_future = isinstance(node, ast.ImportFrom) and node.module == "__future__"
        if not (is_docstring or is_future):
            break
        header_end = node.end_lineno

    # Work on whole lines, each ending in "\n". Splitting only on "\n" keeps
    # the indices aligned with ast's line numbers for \n and \r\n files.
    # NOTE(review): files with bare "\r" line endings are not handled.
    if content and not content.endswith("\n"):
        content += "\n"
    lines = [line + "\n" for line in content.split("\n")[:-1]]

    if header_end == 0:
        # No docstring/__future__ header: still keep a shebang, encoding
        # cookie or leading comment block on top.
        while header_end < len(lines) and lines[header_end].startswith("#"):
            header_end += 1

    # The logging setup must follow every import, including those added here.
    anchor = max(last_import_end, header_end)
    pieces = lines[:header_end] + missing_imports + lines[header_end:anchor]
    if needs_setup:
        pieces.append(logging_setup)
    pieces.extend(lines[anchor:])
    return "".join(pieces)
def fix_argparse_section(content):
    """Replace the argument-parser prologue of ``main()`` with a standard one.

    The span from ``def main():`` through the first following
    ``args = parser.parse_args()`` is replaced by a fixed template. The
    template declares ``--metadata``, ``--output`` and ``--model`` (all
    required), plus ``--api_key`` and ``--image_path``. Any arguments the
    script previously declared in that span are dropped.

    The rewrite applies only when ``parser = argparse.ArgumentParser`` is the
    first statement of ``main`` and the ``parse_args()`` call sits inside
    ``main``'s body. The match never runs past the end of ``main`` into later
    top-level code.

    Args:
        content: Source text of a Python script.

    Returns:
        The rewritten source text, or ``content`` unchanged if no matching
        ``main`` exists.
    """
    argparse_template = '''def main():
    parser = argparse.ArgumentParser(description="Run inference on math-eval dataset")
    parser.add_argument('--metadata', type=str, required=True,
                        help="Path to the metadata CSV file")
    parser.add_argument('--output', type=str, required=True,
                        help="Path to save the results CSV file")
    parser.add_argument('--model', type=str, required=True,
                        help="Model to use for inference")
    parser.add_argument('--api_key', type=str,
                        help="API key for the selected model (required for API models)")
    parser.add_argument('--image_path', type=str,
                        help="Path to the directory containing images")
    args = parser.parse_args()'''
    main_pattern = re.compile(
        r"def main\(\):\s*parser = argparse\.ArgumentParser"
        # Consume any character, but never a newline followed by a
        # non-whitespace character: such a column-0 line marks the end of
        # main's body. Without this guard the lazy match could continue into
        # a later function's parse_args() and delete everything in between.
        r"(?:(?!\n\S).)*?"
        r"args = parser\.parse_args\(\)",
        re.DOTALL,
    )
    # A callable replacement keeps re.sub from processing escapes in the
    # template. When nothing matches, sub() returns the input unchanged.
    return main_pattern.sub(lambda _match: argparse_template, content)
def add_requirements_validation(content):
    """Insert a ``validate_requirements()`` helper and call it from ``main``.

    The inserted helper tries to import pandas, PIL and tqdm. If any are
    missing, it logs the pip package names and exits with status 1. It uses
    the module-level names ``logger`` and ``sys``, which the script must
    already define (``add_error_handling`` adds them).

    Only a top-level ``def main()`` at column 0 is considered, so occurrences
    inside string literals or nested scopes are ignored. Both insertions are
    idempotent: running this twice leaves the text unchanged.

    NOTE(review): if ``main`` begins with a docstring, the call is placed
    before it, which turns the docstring into a plain string expression.

    Args:
        content: Source text of a Python script.

    Returns:
        The updated source text.
    """
    validation_code = '''
def validate_requirements():
    """Validate that required packages are installed."""
    required_packages = {
        'pandas': 'pandas',
        'PIL': 'Pillow',
        'tqdm': 'tqdm'
    }
    missing_packages = []
    for package, pip_name in required_packages.items():
        try:
            __import__(package)
        except ImportError:
            missing_packages.append(pip_name)
    if missing_packages:
        logger.error(f"Missing required packages: {', '.join(missing_packages)}")
        logger.error("Install with: pip install " + " ".join(missing_packages))
        sys.exit(1)
'''
    # Add the helper right before main's definition.
    if "def validate_requirements" not in content:
        main_def = re.search(r"^def main\(\)", content, re.MULTILINE)
        if main_def:
            pos = main_def.start()
            content = content[:pos] + validation_code + "\n" + content[pos:]

    # Look for an actual call on an indented line. A bare substring test would
    # also match the definition line "def validate_requirements():" inserted
    # above, and so would never add the call.
    if not re.search(r"^[ \t]+validate_requirements\(\)", content, re.MULTILINE):
        # Put the call before main's first body statement, skipping blank and
        # comment-only lines and reusing that statement's indentation (tabs
        # or any number of spaces).
        content = re.sub(
            r"^(def main\(\)[^\n]*\n(?:[ \t]*(?:#[^\n]*)?\n)*)([ \t]+)",
            lambda m: f"{m.group(1)}{m.group(2)}validate_requirements()\n{m.group(2)}",
            content,
            count=1,
            flags=re.MULTILINE,
        )
    return content
def process_inference_script(file_path):
    """Apply every standardization pass to one script and rewrite it in place.

    The file is read and written as UTF-8, because the platform default
    encoding is not always UTF-8 and scripts may contain non-ASCII text. The
    file is rewritten only when the passes actually changed something and
    the result still compiles as Python. This prevents a faulty regex rewrite
    from leaving a syntactically broken script on disk.

    Errors are printed and swallowed, so one bad file does not stop the batch
    loop in main().

    Args:
        file_path: Path (str or pathlib.Path) of the script to process.
    """
    print(f"Processing: {file_path}")
    passes = (
        fix_openai_api_call,
        fix_gemini_model_call,
        add_error_handling,
        fix_argparse_section,
        add_requirements_validation,
    )
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            original = f.read()

        content = original
        for apply_fix in passes:
            content = apply_fix(content)

        if content == original:
            print(f"No changes needed: {file_path}")
            return

        # Raises SyntaxError (handled below) instead of writing broken code.
        compile(content, str(file_path), "exec")

        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(content)
        print(f"✅ Fixed: {file_path}")
    except SyntaxError as e:
        print(f"❌ Not writing {file_path}: result would not compile: {e}")
    except Exception as e:
        # Per-file boundary: report and continue with the next script.
        print(f"❌ Error processing {file_path}: {e}")
def main():
    """Command-line entry point: standardize every ``*.py`` under a directory.

    Options:
        --inference_dir  Directory searched recursively (default: ``inference``).
        --dry_run        Only list the files that would be processed.

    Exits via ``parser.error`` (non-zero status, with a usage message) when
    ``--inference_dir`` is not an existing directory.
    """
    parser = argparse.ArgumentParser(description="Standardize inference scripts")
    parser.add_argument('--inference_dir', type=str, default='inference',
                        help='Path to inference directory')
    parser.add_argument('--dry_run', action='store_true',
                        help='Show files that would be processed without modifying them')
    args = parser.parse_args()

    inference_path = Path(args.inference_dir)
    # is_dir() rather than exists(): a regular file would otherwise be
    # accepted and yield no scripts. parser.error() exits with a non-zero
    # status, so shell callers and CI notice the bad path.
    if not inference_path.is_dir():
        parser.error(f"Inference directory not found: {inference_path}")

    # Sorted for a stable, reproducible processing order.
    python_files = sorted(inference_path.rglob("*.py"))
    print(f"Found {len(python_files)} Python files in {inference_path}")

    if args.dry_run:
        print("Files that would be processed:")
        for file_path in python_files:
            print(f" {file_path}")
        return

    for file_path in python_files:
        process_inference_script(file_path)
    print(f"\n✅ Processed {len(python_files)} inference scripts")


if __name__ == "__main__":
    main()