Skip to content

错误的 assert msg['role'] == 'user' #12

@rababit

Description

@rababit

def apply_webrl_format(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Collapse a user/assistant chat history into a single WebRL-style user prompt.

    The first message must be a ``user`` message whose content has the form
    ``<task instruction> Round 0 <observation>``. Later ``assistant`` turns are
    appended followed by a ``Round k`` marker. Later ``user`` turns contribute
    only the text after their last ``Round k`` marker (followed by a blank line).
    Messages at index > 0 with any other role are ignored.

    Args:
        messages: Chat messages with ``'role'`` and ``'content'`` keys.
            ``messages[0]`` must be the user message. A leading ``system``
            message is not supported and must be removed by the caller.

    Returns:
        A one-element list ``[{'role': 'user', 'content': prompt}]``. For an
        empty ``messages`` list the prompt is the empty string.

    Raises:
        AssertionError: if ``messages[0]`` is not a ``user`` message, or if its
            observation is not the ``'** Simplified html **'`` placeholder while
            more than one message is present.
        ValueError: if ``messages[0]`` does not contain ``'Round 0'``.
    """
    # NOTE(review): the pasted original appended to ``formatted_msg`` without
    # ever initialising it, so the first ``+=`` raised UnboundLocalError. The
    # prompt now starts empty. Confirm against upstream whether a fixed prefix
    # (e.g. a system header) is expected here.
    parts: List[str] = []
    for idx, msg in enumerate(messages):
        if idx == 0:
            assert msg['role'] == 'user', (
                "apply_webrl_format expects messages[0] to be a 'user' message, "
                f"got {msg['role']!r}; strip any system message before calling."
            )
            # partition() splits on the FIRST marker only, so an observation
            # that itself contains "Round 0" no longer breaks the unpacking.
            intent, sep, obs = msg['content'].partition('Round 0')
            if not sep:
                raise ValueError("first user message must contain 'Round 0'")
            intent, obs = intent.strip(), obs.strip()
            parts.append(f'Task Instruction: {intent}\n\nRound 0\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n')
            if obs != '** Simplified html **':
                # A real round-0 observation is only accepted for a one-message history.
                assert len(messages) == 1
                parts.append(obs + '\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n')
            else:
                # Placeholder observation: the intent is repeated in its place.
                parts.append(intent + '\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n')
        elif msg['role'] == 'assistant':
            # Assuming strict user/assistant alternation, assistant turns sit at odd
            # indices, so idx // 2 + 1 is the number of the round that follows.
            parts.append(msg['content'].strip() + f'\n\nRound {idx // 2 + 1}\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n')
        elif msg['role'] == 'user':
            # Keep only the observation after this turn's "Round k" marker.
            obs = msg['content'].split(f'Round {idx // 2}\n\n')[-1].strip()
            parts.append(obs + '\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n')

    formatted_msg = ''.join(parts)

    print("apply_web_format 后的messages:")
    print(formatted_msg)

    return [{'role': 'user', 'content': formatted_msg}]

这里的 0 号消息难道不应该是 system message 吗？如果调用方传入的 messages 以 system message 开头，`assert msg['role'] == 'user'` 就会直接失败。

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions