-
Notifications
You must be signed in to change notification settings - Fork 4
Description
def apply_webrl_format(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Flatten a chat history into one WebRL-style prompt message.

    The first message must have role ``'user'`` and contain the task intent
    followed by a ``'Round 0'`` marker and the first observation. Every
    later assistant/user message is appended in order, separated by the
    ``<|eot_id|><|start_header_id|>...<|end_header_id|>`` header strings.

    Args:
        messages: Chat history as ``{'role': ..., 'content': ...}`` dicts.
            Messages after the first whose role is neither ``'assistant'``
            nor ``'user'`` are ignored.

    Returns:
        A one-element list holding a single ``'user'`` message whose
        content is the concatenated prompt (an empty string when
        ``messages`` is empty).

    Raises:
        AssertionError: If the first message is not a user message, or if
            the first observation is a real observation (not the
            ``'** Simplified html **'`` placeholder) while more than one
            message was supplied.
        ValueError: If the first message has no ``'Round 0'`` marker.
    """
    user_header = '\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n'
    assistant_header = '\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n'

    # Accumulate pieces and join once. The accumulator must exist before the
    # loop: the previous code only ever did `formatted_msg += ...`, which
    # raised UnboundLocalError on every call.
    parts = []
    for idx, msg in enumerate(messages):
        if idx == 0:
            assert msg['role'] == 'user'
            # Split on the first marker only, so an observation that itself
            # contains the text 'Round 0' no longer breaks the unpacking.
            intent, marker, obs = msg['content'].partition('Round 0')
            if not marker:
                raise ValueError("first message must contain a 'Round 0' marker")
            intent, obs = intent.strip(), obs.strip()
            parts.append(f'Task Instruction: {intent}\n\nRound 0{user_header}')
            if obs != '** Simplified html **':
                # A real first observation is only expected for a
                # single-message history.
                assert len(messages) == 1
                parts.append(obs + assistant_header)
            else:
                # Placeholder observation: the intent is emitted in its place.
                parts.append(intent + assistant_header)
        elif msg['role'] == 'assistant':
            # The assistant at index idx closes round idx // 2 and opens the next.
            parts.append(
                msg['content'].strip() + f'\n\nRound {idx // 2 + 1}' + user_header
            )
        elif msg['role'] == 'user':
            # Keep only the text after this round's 'Round N' prefix.
            obs = msg['content'].split(f'Round {idx // 2}\n\n')[-1].strip()
            parts.append(obs + assistant_header)

    formatted_msg = ''.join(parts)
    print("apply_web_format 后的messages:")
    print(formatted_msg)
    return [{'role': 'user', 'content': formatted_msg}]
`apply_webrl_format` 里断言第 0 条消息的 role 是 'user'，但 messages 里的 0 号消息不是 system message 吗？