forked from pasquini-dario/LLMmap
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathadd_new_template.py
More file actions
46 lines (33 loc) · 1.85 KB
/
add_new_template.py
File metadata and controls
46 lines (33 loc) · 1.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import argparse
import sys
import tqdm
from LLMmap.inference import load_LLMmap
from LLMmap.dataset_maker import make_dataset_entries_for_new_llm
from LLMmap.prompt_configuration import PromptConfFactory, TRAIN
from LLMmap.llm import load_llm
def main():
parser = argparse.ArgumentParser(description="Generate templates for a new LLM using LLMmap and add it to the template file.")
parser.add_argument('new_llm_name', type=str, help='Name or path of the new LLM')
parser.add_argument('new_llm_type', type=int, help='0:Hugging Face, 1:OpenAI, 2:Anthropic, 3:OpenRouter')
parser.add_argument('--prompt_conf_path', type=str, default='./confs/prompt_configurations/', help='Path to prompt configuration directory')
parser.add_argument('--llmmap_path', type=str, default='./data/pretrained_models/default/', help='Path to the pretrained LLMmap model')
parser.add_argument('--num_prompt_confs', type=int, default=100, help='Number of prompt configurations to sample')
args = parser.parse_args()
conf, llmmap = load_LLMmap(args.llmmap_path)
if not conf['is_open']:
print("Applicable to only open-set inference model. Aborting...")
sys.exit(1)
if not llmmap.ready:
print("No templates found for the model. Aborting...")
sys.exit(1)
if args.new_llm_name in llmmap.templates_map:
print(f"Template for {args.new_llm_name} has already be computed. Aborting...")
sys.exit(1)
new_llm = load_llm(args.new_llm_name, args.new_llm_type)
pc = PromptConfFactory(args.prompt_conf_path)
prompt_confs = pc.sample(args.num_prompt_confs, pool=TRAIN)
entries = make_dataset_entries_for_new_llm(new_llm, conf['queries'], prompt_confs)
new_template = llmmap.compute_template(entries)
llmmap.add_entry_and_save_templates(new_llm.llm_name, new_template)
if __name__ == '__main__':
main()