-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathrefine.py
More file actions
223 lines (193 loc) · 6.94 KB
/
refine.py
File metadata and controls
223 lines (193 loc) · 6.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
import argparse
import json
import logging
import os
import sys
import time
from typing import Any, Dict, List
from imports import CrossEncoder
from pipeline import (
TableSimilarityCache,
get_query_table_scores,
get_table_to_table_scores_attn_threaded_cached,
rerank_by_qt_ttqt_attn,
)
from utils import create_llama_index_documents, transform_retrieved_nodes
def configure_logging(level: str) -> None:
    """Configure root logging to stdout at the requested severity.

    Unrecognized level names silently fall back to INFO.
    """
    resolved_level = getattr(logging, level.upper(), logging.INFO)
    logging.basicConfig(
        stream=sys.stdout,
        level=resolved_level,
        format="%(asctime)s %(levelname)s %(name)s - %(message)s",
    )
def load_expansion_payload(path: str) -> Dict[str, Any]:
    """Load the expansion JSON and ensure it carries a 'questions' key.

    Raises:
        KeyError: when the file lacks the 'questions' key expected from
            run_table_expansion.py output.
    """
    with open(path, "r", encoding="utf-8") as handle:
        data = json.load(handle)
    if "questions" in data:
        return data
    raise KeyError("Expansion file missing 'questions' key produced by run_table_expansion.py")
def load_table_repository(path: str) -> Dict[str, Any]:
    """Read the table repository JSON (schemas and samples) from disk."""
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)
def prune_question(
    question_entry: Dict[str, Any],
    table_repository: Dict[str, Any],
    cross_encoder: CrossEncoder,
    similarity_cache: TableSimilarityCache,
    top_n: int,
    max_workers: int,
) -> Dict[str, Any]:
    """Score one question's candidate tables and keep the top_n best.

    Returns a copy of `question_entry` augmented with the pruned table ids,
    the raw query-table and table-table scores, and the full reranked list.
    """
    candidates = question_entry.get("expanded_table_ids") or []
    query_text = question_entry.get("question", "")

    def _assemble(pruned, qt_scores, tt_scores, ranked):
        # Single place that shapes the enriched output entry.
        return {
            **question_entry,
            "pruned_table_ids": pruned,
            "query_table_scores": qt_scores,
            "table_table_scores": tt_scores,
            "reranked_tables": ranked,
        }

    if not candidates:
        logging.warning(
            "No candidate tables provided for question %s; returning empty pruning result",
            question_entry.get("question_id"),
        )
        return _assemble([], {}, {}, [])

    docs = create_llama_index_documents(table_repository, candidates)
    if not docs:
        logging.warning(
            "Failed to build documents for candidate tables %s; returning candidates as-is",
            candidates,
        )
        return _assemble(candidates[:top_n], {}, {}, [])

    qt_scores = get_query_table_scores(query_text, docs, cross_encoder)
    tt_scores = get_table_to_table_scores_attn_threaded_cached(
        transform_retrieved_nodes(docs),
        cross_encoder,
        qt_scores,
        max_workers=max_workers,
        similarity_cache=similarity_cache,
    )

    ranked = rerank_by_qt_ttqt_attn(qt_scores, tt_scores)
    if not ranked:
        # Fall back to pure query-table relevance when reranking yields nothing.
        logging.warning(
            "Reranking returned no results for question %s; falling back to query scores",
            question_entry.get("question_id"),
        )
        ranked = sorted(qt_scores.items(), key=lambda pair: pair[1], reverse=True)

    kept: List[str] = [table_id for table_id, _ in ranked[:top_n]]
    return _assemble(
        kept,
        {key: float(value) for key, value in qt_scores.items()},
        {key: float(value) for key, value in tt_scores.items()},
        [{"table_id": table_id, "score": float(score)} for table_id, score in ranked],
    )
def save_results(output_path: str, payload: Dict[str, Any]) -> None:
    """Write `payload` as pretty-printed UTF-8 JSON, creating parent dirs."""
    parent_dir = os.path.dirname(output_path) or "."
    os.makedirs(parent_dir, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as sink:
        json.dump(payload, sink, indent=2, ensure_ascii=False)
    logging.info("Saved pruning results to %s", output_path)
def parse_args() -> argparse.Namespace:
    """Build and evaluate the command-line interface for the pruning script."""
    parser = argparse.ArgumentParser(
        description="Refine expanded table sets using attention-based pruning."
    )
    # Required input/output paths.
    parser.add_argument("--expansion-file", required=True,
                        help="JSON file produced by run_table_expansion.py.")
    parser.add_argument("--table-repository", required=True,
                        help="Path to the table repository JSON containing schemas and samples.")
    parser.add_argument("--output-file", required=True,
                        help="Destination JSON file for pruning results.")
    # Model and pruning knobs.
    parser.add_argument("--cross-encoder-model",
                        default="jinaai/jina-reranker-v2-base-multilingual",
                        help="Sentence Transformers cross-encoder model identifier.")
    parser.add_argument("--top-n", type=int, default=5,
                        help="Number of tables to keep after pruning.")
    parser.add_argument("--max-workers", type=int, default=2,
                        help="Thread pool size for table-table scoring.")
    # Caching and diagnostics.
    parser.add_argument("--cache-dir", default="table_cache",
                        help="Directory used for persistent table similarity cache.")
    parser.add_argument("--cache-version", default="v1",
                        help="Version string applied to cache keys (useful when changing parameters).")
    parser.add_argument("--log-level", default="INFO",
                        help="Logging level (DEBUG, INFO, WARNING, ERROR).")
    return parser.parse_args()
def main() -> None:
    """Entry point: load inputs, prune every question, and save a summary JSON."""
    args = parse_args()
    configure_logging(args.log_level)

    logging.info("Loading cross-encoder model %s", args.cross_encoder_model)
    cross_encoder = CrossEncoder(
        args.cross_encoder_model,
        automodel_args={"torch_dtype": "auto"},
        trust_remote_code=True,
    )

    expansion_payload = load_expansion_payload(args.expansion_file)
    table_repository = load_table_repository(args.table_repository)
    cache = TableSimilarityCache(cache_dir=args.cache_dir, cache_version=args.cache_version)

    pruned_questions: List[Dict[str, Any]] = []
    start = time.time()
    for question in expansion_payload["questions"]:
        # BUG FIX: the previous version passed alpha=args.alpha and beta=args.beta,
        # but parse_args() defines no --alpha/--beta options (AttributeError on the
        # Namespace) and prune_question() accepts no such parameters (TypeError),
        # so every run crashed. The spurious keyword arguments are removed.
        pruned_questions.append(
            prune_question(
                question,
                table_repository,
                cross_encoder,
                cache,
                top_n=args.top_n,
                max_workers=args.max_workers,
            )
        )
    elapsed = time.time() - start

    # Echo run configuration alongside the per-question results for traceability.
    summary = {
        "expansion_file": args.expansion_file,
        "table_repository": args.table_repository,
        "cross_encoder_model": args.cross_encoder_model,
        "top_n": args.top_n,
        "max_workers": args.max_workers,
        "cache_dir": args.cache_dir,
        "cache_version": args.cache_version,
        "pruned_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "processing_time_sec": elapsed,
        "questions": pruned_questions,
    }
    save_results(args.output_file, summary)


if __name__ == "__main__":
    main()