-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgpt_test.py
More file actions
176 lines (150 loc) · 7.33 KB
/
gpt_test.py
File metadata and controls
176 lines (150 loc) · 7.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import argparse
import json
import asyncio
import re
import os
from openai import AzureOpenAI
from azure.identity import (
DefaultAzureCredential,
ChainedTokenCredential,
AzureCliCredential,
get_bearer_token_provider,
)
# OAuth scope requested when fetching bearer tokens for the TRAPI endpoint.
scope = "api://trapi/.default"

# Token provider handed to the AzureOpenAI client below. ChainedTokenCredential
# tries its credentials in order: first the Azure CLI login, then a
# DefaultAzureCredential with most of its sources excluded so that, e.g. on
# Azure ML compute, a managed identity can be used instead.
credential = get_bearer_token_provider(
    ChainedTokenCredential(
        AzureCliCredential(),
        DefaultAzureCredential(
            # The CLI is already tried first in the chain above.
            exclude_cli_credential=True,
            # Exclude other credentials we are not interested in.
            exclude_environment_credential=True,
            exclude_shared_token_cache_credential=True,
            exclude_developer_cli_credential=True,
            exclude_powershell_credential=True,
            exclude_interactive_browser_credential=True,
            # NOTE: azure-identity documents this keyword in the singular; the
            # plural spelling ("..._credentials") used previously is not a
            # documented DefaultAzureCredential option.
            exclude_visual_studio_code_credential=True,
            # DEFAULT_IDENTITY_CLIENT_ID is a variable exposed in
            # Azure ML Compute jobs that has the client id of the
            # user-assigned managed identity in it.
            # See https://learn.microsoft.com/en-us/azure/machine-learning/how-to-identity-based-service-authentication#compute-cluster
            # In case it is not set the ManagedIdentityCredential will
            # default to using the system-assigned managed identity, if any.
            managed_identity_client_id=os.environ.get("DEFAULT_IDENTITY_CLIENT_ID"),
        ),
    ),
    scope,
)

# Ensure this is a valid API version, see:
# https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation#latest-ga-api-release
api_version = "2024-10-21"
model_name = "gpt-4o"  # Ensure this is a valid model name
model_version = "2024-11-20"  # Ensure this is a valid model version

# Harmonized deployment name "<model>_<version>" (here "gpt-4o_2024-11-20"),
# keeping only letters, digits, "-" and "_". If your endpoint doesn't have
# harmonized deployment names, set the deployment name directly instead;
# see https://aka.ms/trapi/models
deployment_name = re.sub(
    r"[^a-zA-Z0-9-_]", "", f"{model_name}_{model_version}"
)

# Instance name from https://aka.ms/trapi/models, without the trailing
# "/openai" (the library adds it implicitly).
instance = "gcr/shared"
endpoint = f"https://trapi.research.microsoft.com/{instance}"

# Azure OpenAI client authenticated via the token provider above.
client = AzureOpenAI(
    azure_endpoint=endpoint,
    azure_ad_token_provider=credential,
    api_version=api_version,
)
def read_data(file_path):
    """Read question and answer records from a JSON or JSON Lines file.

    Args:
        file_path: Path ending in ``.jsonl`` (one JSON value per line) or
            ``.json`` (a single JSON document).

    Returns:
        For ``.jsonl``: a list with one parsed value per non-blank line.
        For ``.json``: the parsed document, whatever its top-level type.

    Raises:
        ValueError: If the extension is neither ``.jsonl`` nor ``.json``
            (checked before the file is opened).
    """

    def read_json_lines(path):
        with open(path, "r", encoding="utf-8") as f:
            # Skip blank/whitespace-only lines (e.g. a stray trailing empty
            # line), which json.loads would otherwise reject.
            return [json.loads(line) for line in f if line.strip()]

    def read_json(path):
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    # ".jsonl" is checked first; note "x.jsonl" does not end with ".json".
    if file_path.endswith(".jsonl"):
        return read_json_lines(file_path)
    if file_path.endswith(".json"):
        return read_json(file_path)
    raise ValueError(f"Unsupported file format: {file_path}")
async def call_gpt_api(item, retry_attempts=3):
"""Asynchronously call Azure OpenAI API to generate reasoning process for one item"""
question = item["input"]
context = item["context"]
answer = item["answers"][0]
# Construct chat messages
messages = [
{
"role": "user",
"content": f"Question: {question}\n\nContext: {context}\n\nAnswer: {answer}\n\nProvide a detailed reasoning process to arrive at the answer based on the given context.",
}
]
for attempt in range(retry_attempts):
try:
# Asynchronous GPT API call using Azure OpenAI client
response = await asyncio.to_thread(
client.chat.completions.create, model=deployment_name, messages=messages
)
reasoning = response.choices[0].message.content.strip()
return {
"index": item["index"],
"input": question,
"context": context,
"answer": answer,
"reasoning": reasoning,
}
except Exception as e:
print(
f"Error processing item {item['index']} on attempt {attempt + 1}: {e}"
)
if attempt < retry_attempts - 1:
await asyncio.sleep(2**attempt) # Exponential backoff
else:
return {
"index": item["index"],
"input": question,
"context": context,
"answer": answer,
"reasoning": "Error: " + str(e),
}
async def generate_inference(input_data):
"""Generate reasoning process for all items asynchronously"""
tasks = [call_gpt_api(item) for item in input_data]
results = await asyncio.gather(*tasks)
return results
def save_results(file_path, results):
"""Save reasoning results to a file"""
if file_path.endswith(".jsonl"):
with open(file_path, "w", encoding="utf-8") as f:
for item in results:
f.write(json.dumps(item) + "\n")
elif: file_path.endswith(".json"):
with open(file_path, "w", encoding="utf-8") as f:
json.dump(results, f)
else:
raise ValueError(f"Unsupported file format: {file_path}")
def parse_args():
parser = argparse.ArgumentParser(
description="Generate reasoning process for question-answer pairs"
)
parser.add_argument("--input_file", type=str, required=True, help="Input file path")
parser.add_argument(
"--output_file", type=str, required=True, help="Output file path"
)
parser.add_argument(
"--prompt_template_type", type=str, default="basic", help="Prompt template type"
)
return parser.parse_args()
prompt_template_dict = {
"basic": "Question: {question}\n\nContext: {context}\n\nAnswer: {answer}\n\nProvide a detailed reasoning process to arrive at the answer based on the given context.",
"normal": "Question: {question}\n\nContext: {context}\n\nAnswer: {answer}\n\nYour task:\nPlease produce a clear and logically sound explanation that shows how the answer is derived from the context. Finally, present your final conclusion after \"Final answer:\".",
"cot": " Question: {question}\n\nContext: {context}\n\nAnswer: {answer}\n\nYour task:\nPlease produce a step-by-step reasoning that uncovers the path from the context to the final answer. Clearly demonstrate each inference or sub-action in your explanation. Finally, present your final conclusion after \"Final answer:\"",
"cot-cite": "Question: {question}\n\nContext: {context}\n\nAnswer: {answer}\n\nYour task:\nPlease produce a structured, step-by-step reasoning that references any relevant parts of the context in quotes ("") whenever you use them. Finally, present your final conclusion after \"Final answer:\".",
"mcts": "Question: {question}\n\nContext: {context}\n\nAnswer: {answer}\n\nYour task:\nPlease adopt a multi-phase approach to thoroughly examine the given context, refining your ideas at each stage. Provide the reasoning details step by step. Finally, present your final conclusion after \"Final answer:\".",
}
if __name__ == "__main__":
# Read input data
args = parse_args()
input_file = args.input_file
output_file = args.output_file
prompt_template = prompt_template_dict[args.prompt_template_type]
input_data = read_data(input_file)
# Generate reasoning process asynchronously
results = asyncio.run(generate_inference(input_data))
# Save results
save_results(output_file, results)
print(f"Reasoning results saved to {output_file}")