|
| 1 | + |
| 2 | +from dataclasses import dataclass, field |
| 3 | +from guidance import models, gen, block |
| 4 | +from typing import ClassVar, Optional |
| 5 | +import pydantic |
| 6 | +from pydantic import create_model |
| 7 | +from functools import cached_property |
| 8 | +import json |
| 9 | +import os |
| 10 | + |
| 11 | +import guidance |
| 12 | +from guidance import one_or_more, select, zero_or_more |
| 13 | +from simple_parsing import Serializable |
| 14 | +import numpy as np |
| 15 | + |
| 16 | +from mmz.agents.with_guidance import GuidanceLlamaCppConfig |
| 17 | +from mmz.agents import tools as mzt |
| 18 | + |
| 19 | + |
# stateless=True indicates this function does not depend on LLM generations
@guidance(stateless=True)
def reference_selection(lm, n_total_refs, selection_name='selected_ixes'):
    """Grammar: one or more picks from the digit strings '0'..str(n_total_refs - 1).

    Each pick is captured under *selection_name* on the model state.
    """
    digit_choices = list(map(str, range(n_total_refs)))
    return lm + one_or_more(select(digit_choices, name=selection_name))
| 25 | + |
| 26 | + |
| 27 | +@guidance(stateless=True) |
| 28 | +def operator(lm): |
| 29 | + return lm + select(['+' , '*', '**', '/', '-']) |
| 30 | + |
| 31 | + |
def get_summary_relevance_prompt(query: str, summaries: list[dict]) -> str:
    """Build a prompt asking the LLM to rank search-result summaries by relevance.

    Args:
        query: the user's search query, embedded verbatim in the prompt.
        summaries: dicts each providing 'title' and 'summary' keys.

    Returns:
        The full prompt text.  The model is instructed to respond with ONLY a
        JSON array of 0-based reference indices, most relevant first.
    """
    from datetime import datetime

    # Anchor "timeliness" to today's date so the model can judge recency.
    current_time = datetime.now().strftime("%Y-%m-%d")
    # NOTE: removed dead code — a shuffled index list was computed and printed
    # but never used (its example line in the prompt was commented out).
    results_json = json.dumps(
        [{'reference index': ii,
          'title': s['title'],
          'summary': s['summary']}
         for ii, s in enumerate(summaries)],
        indent=2)
    prompt = f"""Analyze these search results and provide a ranked list of the most relevant ones.

IMPORTANT: Evaluate and rank based on these criteria (in order of importance):
1. Timeliness - current/recent information as of {current_time}
2. Direct relevance to query: "{query}"
3. Source reliability (prefer official sources, established websites)
4. Factual accuracy (cross-reference major claims)

Search results to evaluate:
 {results_json}

Return ONLY a JSON array of the 0-based reference index, ranked from most to least relevant.
Include ONLY indices that meet ALL criteria, with the most relevant first.
You should list all {len(summaries)} indices in your response.
You should not output any number larger than {len(summaries) - 1}
Respond with ONLY the JSON array, no other text."""
    return prompt
| 61 | + |
def get_summary_relevance_scalar_prompt(query: str, summary: dict) -> str:
    """Build a prompt asking the LLM to score one summary's relevance (0-100).

    Args:
        query: the user's search query, embedded verbatim in the prompt.
        summary: a dict providing 'title' and 'summary' keys.

    Returns:
        The full prompt text.  The model is told to reply with a single number:
        100 = most relevant / likely answers the query, 0 = not relevant.
    """
    from datetime import datetime

    # Anchor "timeliness" to today's date so the model can judge recency.
    current_time = datetime.now().strftime("%Y-%m-%d")
    result_json = json.dumps({'title': summary['title'],
                              'summary': summary['summary']},
                             indent=2)
    prompt = f"""Analyze these search results and provide a number
between 0 and 100 according to its relevance to the users query,
100 being the most relevant and likely answers the query
0 being the least relevant and does not answer the query

IMPORTANT: Evaluate and estimate relevance based on these criteria (in order of importance):
1. Timeliness - current/recent information as of {current_time}
2. Direct relevance to query: "{query}"

Search results to evaluate:
 {result_json}

Respond only with a number within 0 and 100 and nothing else: """
    return prompt
| 82 | + |
@guidance
def relevance_by_regex(llm, query, summaries):
    """Append the ranking prompt and constrain the reply to digits inside '[ ... ]'."""
    prompt = get_summary_relevance_prompt(query, summaries=summaries)
    return llm + prompt + '[ ' + gen(regex=r'\d+') + ']'
| 88 | + |
| 89 | + |
@guidance
def relevance_by_selection(llm, query, summaries, selection_name='selected_ixes'):
    """Append the ranking prompt and force the model to pick reference indices.

    Picked digits are captured under *selection_name* on the returned state.
    """
    prompt = get_summary_relevance_prompt(query, summaries=summaries)
    index_grammar = reference_selection(n_total_refs=len(summaries),
                                        selection_name=selection_name)
    return llm + prompt + '[ ' + index_grammar + ']'
| 99 | + |
| 100 | + |
def get_list_of_int_grammar(name="integers"):
    """Return a guidance JSON grammar for an object with one list-of-int field.

    The generated JSON has the shape {"<name>": [<int>, ...]} and the raw text
    is captured on the model state under *name*.
    """
    # pydantic's create_model expects (annotation, default) field tuples; a
    # bare type is misread as a *default value* (pydantic v1) or rejected.
    # `...` (Ellipsis) marks the field as required.
    schema = create_model(f"list_of_{name}", **{name: (list[int], ...)})
    return guidance.json(name=name, schema=schema)
| 108 | + |
| 109 | + |
def get_list_additional_topics_prompt(query: str) -> str:
    """Build a prompt asking the LLM for a JSON list of topics related to *query*."""
    from datetime import datetime

    # Local, naive timestamp — gives the model a sense of "now".
    t = str(datetime.now())
    # NOTE: the original concatenation ran the time sentence straight into the
    # instruction ("...{t}Given the users...") — separate them with a newline.
    prompt = (f"The local time is {t}\n"
              "Given the users query, produce a JSON list of other topics related to their query.\n"
              f"Here is their query: {query}\n"
              "Provide a list of JSON strings of related topics: ")
    return prompt
| 120 | + |
| 121 | + |
def get_list_of_str_grammar(name="strings"):
    """Return a guidance JSON grammar for an object with one list-of-str field.

    The generated JSON has the shape {"<name>": ["...", ...]} and the raw text
    is captured on the model state under *name*.
    """
    # (annotation, ...) marks the field as required; a bare type would be
    # misread by pydantic's create_model as a default value.
    schema = create_model(f"list_of_{name}", **{name: (list[str], ...)})
    return guidance.json(name=name, schema=schema)
| 126 | + |
| 127 | + |
def get_q_and_a_grammar(name='answer'):
    """Return a guidance JSON grammar for {"<name>": <str>, "confidence": <int>}."""
    # (annotation, ...) marks each field as required; bare types would be
    # misread by pydantic's create_model as default values.
    schema = create_model(f"{name}", **{name: (str, ...), 'confidence': (int, ...)})
    return guidance.json(name=name, schema=schema)
| 132 | + |
| 133 | + |
@guidance
def select_next(choices):
    # NOTE(review): unimplemented stub — body is `pass`, so it returns None.
    # Guidance-decorated functions conventionally take the model state `lm`
    # as their first argument (see the other @guidance functions in this
    # file); confirm the intended signature before wiring this in.
    pass
| 137 | + |
| 138 | + |
@guidance
def relevance_by_json_int_list(llm, query, summaries, name='selected_ixes'):
    """Append the ranking prompt and constrain the reply to a JSON int-list object.

    The raw JSON text is captured on the model state under *name*.
    """
    prompt = get_summary_relevance_prompt(query, summaries=summaries)
    int_list_grammar = get_list_of_int_grammar(name=name)
    return llm + prompt + int_list_grammar
| 147 | + |
| 148 | + |
@guidance
def relevance_scalar(llm, query, summary, name='relevance_magnitude'):
    """Append the 0-100 relevance prompt and constrain the reply to JSON.

    The reply is captured under *name* as JSON text: {"<name>": <int>}.
    """
    # (annotation, ...) marks the field as required; a bare `int` would be
    # misread by pydantic's create_model as a default value.  The local
    # create_model import was redundant — it is imported at module level.
    schema = create_model(f"scalar_{name}", **{name: (int, ...)})
    prompt = get_summary_relevance_scalar_prompt(query, summary=summary)
    return llm + prompt + guidance.json(name=name, schema=schema)
| 155 | + |
| 156 | + |
@dataclass
class GuidanceGuide(Serializable):
    """High-level driver that runs this module's guidance grammars against an LLM.

    Attributes:
        model_preset: named preset used to build a model config when none is given.
        model_config: explicit llama.cpp config; built from the preset if None.
    """
    model_preset: Optional[str] = 'med'

    model_config: Optional[GuidanceLlamaCppConfig] = None

    def __post_init__(self):
        # Resolve the preset only when no explicit config was supplied.
        if self.model_config is None:
            self.model_config = GuidanceLlamaCppConfig.get_preset(self.model_preset)

    @property
    def model(self) -> models.Model:
        """The underlying guidance model object."""
        return self.model_config.model

    def get_relevant_ixes_from_summary(self, user_q: str,
                                       summaries: list[dict],
                                       relevance_grammar: callable = relevance_by_selection,
                                       as_list: bool = True):
        """Select/rank reference indices of *summaries* relevant to *user_q*.

        Args:
            user_q: the user's query.
            summaries: dicts each providing 'title' and 'summary' keys.
            relevance_grammar: a @guidance function that captures its result
                under the key 'selected_ixes'.
            as_list: when True, parse the captured text into list[int].

        Returns:
            The raw captured string, or a list of 0-based indices into
            *summaries* when as_list is True.
        """
        res = self.model + relevance_grammar(user_q, summaries=summaries)
        res = res['selected_ixes']
        if as_list:
            if relevance_grammar == relevance_by_selection:
                # The selection grammar captures bare digits; wrap them in
                # brackets so e.g. "0" parses as [0].
                # NOTE(review): concatenated picks like "01" would not parse
                # as two separate indices — confirm capture behavior upstream.
                res = json.loads(f"[{res}]")
            else:
                # `res` is already the captured JSON text of the form
                # '{"selected_ixes": [...]}'.  BUGFIX: the original indexed
                # the string with a str key (TypeError); parse it instead.
                res = json.loads(res)['selected_ixes']
        return res

    def filter_to_relevant_summeries(self, user_q: str,
                                     summaries: list[dict],
                                     relevance_grammar: callable = relevance_by_selection) -> list[dict]:
        """Return only the summaries judged relevant, most relevant first.

        (Method-name spelling kept as-is for backward compatibility.)
        """
        ixes_to_keep = self.get_relevant_ixes_from_summary(
            user_q=user_q,
            summaries=summaries,
            relevance_grammar=relevance_grammar
        )
        return [summaries[i] for i in ixes_to_keep]

    def get_relevance_score(self, user_q: str, summary: dict) -> int:
        """Score a single summary's relevance to *user_q* on a 0-100 scale."""
        res = self.model + relevance_scalar(query=user_q,
                                            summary=summary)
        # Captured text is '{"relevance_magnitude": <int>}'.
        res = json.loads(res['relevance_magnitude'])['relevance_magnitude']
        return int(res)

    def expand_topic_grammar(self, user_q: str):
        """Model state after prompting for related topics with a JSON-list grammar."""
        return (self.model
                + get_list_additional_topics_prompt(query=user_q)
                + get_list_of_str_grammar(name='topics'))

    def expand_topics(self, user_q: str,
                      as_list: bool = True,
                      deduplicate_list: bool = True
                      ) -> str | list[str]:
        """Return topics related to *user_q* as raw JSON text or a list of strings."""
        # First ['topics'] access pulls the raw captured JSON text from guidance.
        res = self.expand_topic_grammar(user_q=user_q)['topics']
        if not as_list:
            return res
        # Second ['topics'] access reads the key out of the parsed JSON object.
        topic_l = json.loads(res)['topics']
        if deduplicate_list:
            # NOTE: set() de-duplication does not preserve the model's ordering.
            topic_l = list(set(topic_l))
        return topic_l

    def answer_query(self, user_q: str, content):
        """NOTE(review): unimplemented stub — returns None."""
        return
| 223 | + |
| 224 | + |
0 commit comments