From 2a035fc0fb3ff303d4f0747400305c116bc77dd9 Mon Sep 17 00:00:00 2001 From: TechNomad-ds <18136835112@163.com> Date: Sat, 28 Jun 2025 11:57:18 +0800 Subject: [PATCH 1/3] change the construction of SchemaLinking operator --- .../generate/Text2SQL/SchemaLinking.py | 863 +++++------------- 1 file changed, 217 insertions(+), 646 deletions(-) diff --git a/dataflow/operators/generate/Text2SQL/SchemaLinking.py b/dataflow/operators/generate/Text2SQL/SchemaLinking.py index 294df85..a64c67b 100644 --- a/dataflow/operators/generate/Text2SQL/SchemaLinking.py +++ b/dataflow/operators/generate/Text2SQL/SchemaLinking.py @@ -1,701 +1,272 @@ from tqdm import tqdm -from transformers import AutoTokenizer, AutoConfig, XLMRobertaXLModel -from transformers.trainer_utils import set_seed -import torch -import torch.nn as nn -import numpy as np -import random -import pandas as pd +import json +import re +from sqlglot.optimizer.qualify import qualify +from sqlglot import parse_one, exp from dataflow.utils.registry import OPERATOR_REGISTRY from dataflow import get_logger from dataflow.core import OperatorABC from dataflow.utils.storage import DataFlowStorage -class SchemaItemClassifier(nn.Module): - def __init__(self, model_name_or_path, mode): - super(SchemaItemClassifier, self).__init__() - if mode in ["eval", "test"]: - config = AutoConfig.from_pretrained(model_name_or_path) - self.plm_encoder = XLMRobertaXLModel(config) - elif mode == "train": - self.plm_encoder = XLMRobertaXLModel.from_pretrained(model_name_or_path) - else: - raise ValueError() - - self.plm_hidden_size = self.plm_encoder.config.hidden_size - - # column cls head - self.column_info_cls_head_linear1 = nn.Linear(self.plm_hidden_size, 256) - self.column_info_cls_head_linear2 = nn.Linear(256, 2) - - # column bi-lstm layer - self.column_info_bilstm = nn.LSTM( - input_size = self.plm_hidden_size, - hidden_size = int(self.plm_hidden_size/2), - num_layers = 2, - dropout = 0, - bidirectional = True - ) - - # linear layer after column bi-lstm layer - self.column_info_linear_after_pooling = nn.Linear(self.plm_hidden_size, self.plm_hidden_size) - - # table cls head - self.table_name_cls_head_linear1 = nn.Linear(self.plm_hidden_size, 256) - self.table_name_cls_head_linear2 = nn.Linear(256, 2) - - # table bi-lstm pooling layer - self.table_name_bilstm = nn.LSTM( - input_size = self.plm_hidden_size, - hidden_size = int(self.plm_hidden_size/2), - num_layers = 2, - dropout = 0, - bidirectional = True - ) - # linear layer after table bi-lstm layer - self.table_name_linear_after_pooling = nn.Linear(self.plm_hidden_size, self.plm_hidden_size) - - # activation function - self.leakyrelu = nn.LeakyReLU() - self.tanh = nn.Tanh() - - # table-column cross-attention layer - self.table_column_cross_attention_layer = nn.MultiheadAttention(embed_dim = self.plm_hidden_size, num_heads = 8) - - # dropout function, p=0.2 means randomly set 20% neurons to 0 - self.dropout = nn.Dropout(p = 0.2) - - def table_column_cross_attention( - self, - table_name_embeddings_in_one_db, - column_info_embeddings_in_one_db, - column_number_in_each_table - ): - table_num = table_name_embeddings_in_one_db.shape[0] - table_name_embedding_attn_list = [] - for table_id in range(table_num): - table_name_embedding = table_name_embeddings_in_one_db[[table_id], :] - column_info_embeddings_in_one_table = column_info_embeddings_in_one_db[ - sum(column_number_in_each_table[:table_id]) : sum(column_number_in_each_table[:table_id+1]), :] - - table_name_embedding_attn, _ = 
self.table_column_cross_attention_layer( - table_name_embedding, - column_info_embeddings_in_one_table, - column_info_embeddings_in_one_table - ) - - table_name_embedding_attn_list.append(table_name_embedding_attn) - - table_name_embeddings_in_one_db = table_name_embeddings_in_one_db + torch.cat(table_name_embedding_attn_list, dim = 0) - - table_name_embeddings_in_one_db = torch.nn.functional.normalize(table_name_embeddings_in_one_db, p=2.0, dim=1) - - return table_name_embeddings_in_one_db - - def table_column_cls( - self, - encoder_input_ids, - encoder_input_attention_mask, - batch_aligned_column_info_ids, - batch_aligned_table_name_ids, - batch_column_number_in_each_table - ): - batch_size = encoder_input_ids.shape[0] - - encoder_output = self.plm_encoder( - input_ids = encoder_input_ids, - attention_mask = encoder_input_attention_mask, - return_dict = True - ) - - batch_table_name_cls_logits, batch_column_info_cls_logits = [], [] - - for batch_id in range(batch_size): - column_number_in_each_table = batch_column_number_in_each_table[batch_id] - sequence_embeddings = encoder_output["last_hidden_state"][batch_id, :, :] - - aligned_table_name_ids = batch_aligned_table_name_ids[batch_id] - aligned_column_info_ids = batch_aligned_column_info_ids[batch_id] - - table_name_embedding_list, column_info_embedding_list = [], [] - - for table_name_ids in aligned_table_name_ids: - table_name_embeddings = sequence_embeddings[table_name_ids, :] - - output_t, (hidden_state_t, cell_state_t) = self.table_name_bilstm(table_name_embeddings) - table_name_embedding = hidden_state_t[-2:, :].view(1, self.plm_hidden_size) - table_name_embedding_list.append(table_name_embedding) - table_name_embeddings_in_one_db = torch.cat(table_name_embedding_list, dim = 0) - table_name_embeddings_in_one_db = self.leakyrelu(self.table_name_linear_after_pooling(table_name_embeddings_in_one_db)) - - for column_info_ids in aligned_column_info_ids: - column_info_embeddings = sequence_embeddings[column_info_ids, :] - - output_c, (hidden_state_c, cell_state_c) = self.column_info_bilstm(column_info_embeddings) - column_info_embedding = hidden_state_c[-2:, :].view(1, self.plm_hidden_size) - column_info_embedding_list.append(column_info_embedding) - column_info_embeddings_in_one_db = torch.cat(column_info_embedding_list, dim = 0) - column_info_embeddings_in_one_db = self.leakyrelu(self.column_info_linear_after_pooling(column_info_embeddings_in_one_db)) - - table_name_embeddings_in_one_db = self.table_column_cross_attention( - table_name_embeddings_in_one_db, - column_info_embeddings_in_one_db, - column_number_in_each_table - ) - - table_name_embeddings_in_one_db = self.table_name_cls_head_linear1(table_name_embeddings_in_one_db) - table_name_embeddings_in_one_db = self.dropout(self.leakyrelu(table_name_embeddings_in_one_db)) - table_name_cls_logits = self.table_name_cls_head_linear2(table_name_embeddings_in_one_db) - - column_info_embeddings_in_one_db = self.column_info_cls_head_linear1(column_info_embeddings_in_one_db) - column_info_embeddings_in_one_db = self.dropout(self.leakyrelu(column_info_embeddings_in_one_db)) - column_info_cls_logits = self.column_info_cls_head_linear2(column_info_embeddings_in_one_db) - - batch_table_name_cls_logits.append(table_name_cls_logits) - batch_column_info_cls_logits.append(column_info_cls_logits) - - return batch_table_name_cls_logits, batch_column_info_cls_logits - - def forward( - self, - encoder_input_ids, - encoder_attention_mask, - batch_aligned_column_info_ids, - batch_aligned_table_name_ids, 
- batch_column_number_in_each_table, - ): - batch_table_name_cls_logits, batch_column_info_cls_logits \ - = self.table_column_cls( - encoder_input_ids, - encoder_attention_mask, - batch_aligned_column_info_ids, - batch_aligned_table_name_ids, - batch_column_number_in_each_table - ) - - return { - "batch_table_name_cls_logits" : batch_table_name_cls_logits, - "batch_column_info_cls_logits": batch_column_info_cls_logits - } - -class SchemaItemClassifierInference(): - def __init__(self, model_save_path): - set_seed(42) - self.tokenizer = AutoTokenizer.from_pretrained(model_save_path, add_prefix_space = True) - self.model = SchemaItemClassifier(model_save_path, "test") - self.model.load_state_dict(torch.load(model_save_path + "/dense_classifier.pt", map_location=torch.device('cpu')), strict=False) - if torch.cuda.is_available(): - self.model = self.model.cuda() - self.model.eval() - - def prepare_inputs_and_labels(self, sample): - table_names = [table["table_name"] for table in sample["schema"]["schema_items"]] - column_names = [table["column_names"] for table in sample["schema"]["schema_items"]] - column_num_in_each_table = [len(table["column_names"]) for table in sample["schema"]["schema_items"]] - - column_name_word_indices, table_name_word_indices = [], [] - - input_words = [sample["text"]] - for table_id, table_name in enumerate(table_names): - input_words.append("|") - input_words.append(table_name) - table_name_word_indices.append(len(input_words) - 1) - input_words.append(":") - - for column_name in column_names[table_id]: - input_words.append(column_name) - column_name_word_indices.append(len(input_words) - 1) - input_words.append(",") - - input_words = input_words[:-1] - - tokenized_inputs = self.tokenizer( - input_words, - return_tensors="pt", - is_split_into_words = True, - padding = "max_length", - max_length = 512, - truncation = True - ) - - column_name_token_indices, table_name_token_indices = [], [] - word_indices = tokenized_inputs.word_ids(batch_index = 0) - - for column_name_word_index in column_name_word_indices: - column_name_token_indices.append([token_id for token_id, word_index in enumerate(word_indices) if column_name_word_index == word_index]) - - for table_name_word_index in table_name_word_indices: - table_name_token_indices.append([token_id for token_id, word_index in enumerate(word_indices) if table_name_word_index == word_index]) - - encoder_input_ids = tokenized_inputs["input_ids"] - encoder_input_attention_mask = tokenized_inputs["attention_mask"] - - if torch.cuda.is_available(): - encoder_input_ids = encoder_input_ids.cuda() - encoder_input_attention_mask = encoder_input_attention_mask.cuda() - - return encoder_input_ids, encoder_input_attention_mask, \ - column_name_token_indices, table_name_token_indices, column_num_in_each_table - - def get_schema(self, tables_and_columns): - schema_items = [] - table_names = list(dict.fromkeys([t for t, c in tables_and_columns])) - for table_name in table_names: - schema_items.append( - { - "table_name": table_name, - "column_names": [c for t, c in tables_and_columns if t == table_name] - } - ) - - return {"schema_items": schema_items} - - def get_sequence_length(self, text, tables_and_columns, tokenizer): - table_names = [t for t, c in tables_and_columns] - table_names = list(dict.fromkeys(table_names)) - - column_names = [] - for table_name in table_names: - column_names.append([c for t, c in tables_and_columns if t == table_name]) - - input_words = [text] - for table_id, table_name in enumerate(table_names): - 
input_words.append("|") - input_words.append(table_name) - input_words.append(":") - for column_name in column_names[table_id]: - input_words.append(column_name) - input_words.append(",") - input_words = input_words[:-1] - - tokenized_inputs = tokenizer(input_words, is_split_into_words = True) - - return len(tokenized_inputs["input_ids"]) - - def split_sample(self, sample, tokenizer): - text = sample["text"] - - table_names = [] - column_names = [] - for table in sample["schema"]["schema_items"]: - table_names.append(table["table_name"] + " ( " + table["table_comment"] + " ) " \ - if table["table_comment"] != "" else table["table_name"]) - column_names.append([column_name + " ( " + column_comment + " ) " \ - if column_comment != "" else column_name \ - for column_name, column_comment in zip(table["column_names"], table["column_comments"])]) - - splitted_samples = [] - recorded_tables_and_columns = [] - - for table_idx, table_name in enumerate(table_names): - for column_name in column_names[table_idx]: - if self.get_sequence_length(text, recorded_tables_and_columns + [[table_name, column_name]], tokenizer) < 500: - recorded_tables_and_columns.append([table_name, column_name]) - else: - splitted_samples.append( - { - "text": text, - "schema": self.get_schema(recorded_tables_and_columns) - } - ) - recorded_tables_and_columns = [[table_name, column_name]] - - splitted_samples.append( - { - "text": text, - "schema": self.get_schema(recorded_tables_and_columns) - } - ) - - return splitted_samples - - def merge_pred_results(self, sample, pred_results): - table_names = [] - column_names = [] - for table in sample["schema"]["schema_items"]: - table_names.append(table["table_name"] + " ( " + table["table_comment"] + " ) " \ - if table["table_comment"] != "" else table["table_name"]) - column_names.append([column_name + " ( " + column_comment + " ) " \ - if column_comment != "" else column_name \ - for column_name, column_comment in zip(table["column_names"], table["column_comments"])]) - - merged_results = [] - for table_id, table_name in enumerate(table_names): - table_prob = 0 - column_probs = [] - for result_dict in pred_results: - if table_name in result_dict: - if table_prob < result_dict[table_name]["table_prob"]: - table_prob = result_dict[table_name]["table_prob"] - column_probs += result_dict[table_name]["column_probs"] - - merged_results.append( - { - "table_name": table_name, - "table_prob": table_prob, - "column_names": column_names[table_id], - "column_probs": column_probs - } - ) - return merged_results - - def lista_contains_listb(self, lista, listb): - for b in listb: - if b not in lista: - return 0 - - return 1 - - def predict_one(self, sample): - encoder_input_ids, encoder_input_attention_mask, column_name_token_indices,\ - table_name_token_indices, column_num_in_each_table = self.prepare_inputs_and_labels(sample) - - with torch.no_grad(): - model_outputs = self.model( - encoder_input_ids, - encoder_input_attention_mask, - [column_name_token_indices], - [table_name_token_indices], - [column_num_in_each_table] - ) - - table_logits = model_outputs["batch_table_name_cls_logits"][0] - table_pred_probs = torch.nn.functional.softmax(table_logits, dim = 1)[:, 1].cpu().tolist() - - column_logits = model_outputs["batch_column_info_cls_logits"][0] - column_pred_probs = torch.nn.functional.softmax(column_logits, dim = 1)[:, 1].cpu().tolist() - - splitted_column_pred_probs = [] - for table_id, column_num in enumerate(column_num_in_each_table): - 
splitted_column_pred_probs.append(column_pred_probs[sum(column_num_in_each_table[:table_id]): sum(column_num_in_each_table[:table_id]) + column_num]) - column_pred_probs = splitted_column_pred_probs - - result_dict = dict() - for table_idx, table in enumerate(sample["schema"]["schema_items"]): - result_dict[table["table_name"]] = { - "table_name": table["table_name"], - "table_prob": table_pred_probs[table_idx], - "column_names": table["column_names"], - "column_probs": column_pred_probs[table_idx], - } - - return result_dict - - def predict(self, test_sample): - splitted_samples = self.split_sample(test_sample, self.tokenizer) - pred_results = [] - for splitted_sample in splitted_samples: - pred_results.append(self.predict_one(splitted_sample)) - - return self.merge_pred_results(test_sample, pred_results) - - def evaluate_coverage(self, dataset, logger): - max_k = 100 - total_num_for_table_coverage, total_num_for_column_coverage = 0, 0 - table_coverage_results = [0]*max_k - column_coverage_results = [0]*max_k - - for data in dataset: - indices_of_used_tables = [idx for idx, label in enumerate(data["table_labels"]) if label == 1] - pred_results = self.predict(data) - # print(pred_results) - table_probs = [res["table_prob"] for res in pred_results] - for k in range(max_k): - indices_of_top_k_tables = np.argsort(-np.array(table_probs), kind="stable")[:k+1].tolist() - if self.lista_contains_listb(indices_of_top_k_tables, indices_of_used_tables): - table_coverage_results[k] += 1 - total_num_for_table_coverage += 1 - - for table_idx in range(len(data["table_labels"])): - indices_of_used_columns = [idx for idx, label in enumerate(data["column_labels"][table_idx]) if label == 1] - if len(indices_of_used_columns) == 0: - continue - column_probs = pred_results[table_idx]["column_probs"] - for k in range(max_k): - indices_of_top_k_columns = np.argsort(-np.array(column_probs), kind="stable")[:k+1].tolist() - if self.lista_contains_listb(indices_of_top_k_columns, indices_of_used_columns): - column_coverage_results[k] += 1 - - total_num_for_column_coverage += 1 - - logger.info(f"total_num_for_table_coverage:{total_num_for_table_coverage}") - logger.info(f"table_coverage_results:{table_coverage_results}") - logger.info(f"total_num_for_column_coverage:{total_num_for_column_coverage}") - logger.info(f"column_coverage_results:{column_coverage_results}") +SQLITE_RESERVED_KEYWORDS = { + "abort", "action", "add", "after", "all", "alter", "analyze", "and", "as", "asc", "attach", "autoincrement", + "before", "begin", "between", "by", "cascade", "case", "cast", "check", "collate", "column", "commit", "conflict", + "constraint", "create", "cross", "current_date", "current_time", "current_timestamp", "database", "default", + "deferrable", "deferred", "delete", "desc", "detach", "distinct", "drop", "each", "else", "end", "escape", "except", + "exclusive", "exists", "explain", "fail", "for", "foreign", "from", "full", "glob", "group", "having", "if", + "ignore", "immediate", "in", "index", "indexed", "initially", "inner", "insert", "instead", "intersect", "into", + "is", "isnull", "join", "key", "left", "like", "limit", "natural", "no", "not", "notnull", "null", "of", + "offset", "on", "or", "order", "outer", "plan", "pragma", "primary", "query", "raise", "recursive", "references", + "regexp", "reindex", "release", "rename", "replace", "restrict", "right", "rollback", "row", "savepoint", "select", + "set", "table", "temp", "temporary", "then", "to", "trigger", "union", "unique", "update", "using", + "vacuum", 
"values", "view", "virtual", "when", "where", "with", "without" +} @OPERATOR_REGISTRY.register() class SchemaLinking(OperatorABC): - def __init__(self, table_info_file: str, - model_path: str, - selection_mode: str = "eval", - num_top_k_tables: int = 5, - num_top_k_columns: int = 5 - ): - self.input_table_file = table_info_file - self.model_path = model_path - self.selection_mode = selection_mode - self.num_top_k_tables = num_top_k_tables - self.num_top_k_columns = num_top_k_columns + def __init__(self, table_info_file: str): + self.table_info_file = table_info_file self.logger = get_logger() + self.schema_cache = {} @staticmethod def get_desc(lang): if lang == "zh": return ( - "该算子用于提取出数据库模式链接。\n\n" + "该算子用于通过解析SQL语句提取使用的数据库Schema。\n\n" "输入参数:\n" - "- input_table_file:输入文件路径,数据库表信息\n" + "- table_info_file:tables.jsonl文件路径,包含数据库Schema信息\n" "- input_sql_key:SQL语句键\n" - "- input_table_names_original_key:table file中原始表名字段\n" - "- input_table_names_statement_key:table file中表名说明字段\n" - "- input_column_names_original_key:table file中原始列名字段\n" - "- input_column_names_statement_key:table file中列名说明字段\n" - "- num_top_k_tables:保留的最大表数量\n" - "- num_top_k_columns:每个表保留的最大列数量\n" - "- selection_mode:模型链接模式,eval或train\n" - "- model_path:模式链接模型路径,只在eval模式下需要\n" - "- input_question_key: question key,只在train模式下需要\n" - "- input_dbid_key:db_id key,数据库名,只在train模式下需要\n" - "注意:eval模式需要下载sic_merged模型(quark netdisk or google drive)并在参数中指明模型路径\n\n" + "- input_dbid_key:db_id key,数据库名\n\n" "输出参数:\n" - "- output_key:筛选提取的数据库模式信息,保留的表名和列名" + "- output_used_schema_key:SQL中实际使用的表和列信息,格式为字典,键为表名,值为列名列表" ) elif lang == "en": return ( - "This operator extracts the database schema linking.\n\n" + "This operator extracts used database schema by parsing SQL statements.\n\n" "Input parameters:\n" - "- input_table_file: Input file path, database table information\n" + "- table_info_file: Path to tables.jsonl file containing database schema information\n" "- input_sql_key: SQL statement key\n" - "- input_table_names_original_key: Original table name field in the table file\n" - "- input_table_names_statement_key: Table name description field in the table file\n" - "- input_column_names_original_key: Original column name field in the table file\n" - "- input_column_names_statement_key: Column name description field in the table file\n" - "- num_top_k_tables: Maximum number of tables to retain\n" - "- num_top_k_columns: Maximum number of columns to retain for each table\n" - "- selection_mode: Model linking mode, eval or train\n" - "- model_path: Path to the schema item classifier model, required only in eval mode\n" - "- input_question_key: Question key, required only in train mode\n" - "- input_dbid_key: db_id key, database name, required only in train mode\n" - "Note: In eval mode, you need to download the sic_merged model (from Quark Netdisk or Google Drive) and specify the model path in the parameters.\n\n" + "- input_dbid_key: db_id key, database name\n\n" "Output parameters:\n" - "- output_key: Extracted database schema information, retaining table names and column names." + "- output_used_schema_key: Actually used tables and columns in SQL, formatted as dict with table names as keys and column lists as values" ) else: - return "AnswerExtraction_qwenmatheval performs mathematical answer normalization and standardization." - + return "Schema linking operator for Text2SQL tasks using sqlglot parsing." 
-    def _process_data(self, questions_df, schemas_df):
-        schemas_dict = {row[self.input_dbid_key]: row for _, row in schemas_df.iterrows()}
-
-        processed_data = []
-
-        for _, row in questions_df.iterrows():
-            db_id = row[self.input_dbid_key]
-            schema = schemas_dict.get(db_id, {})
+    def load_schema_info(self):
+        if self.schema_cache:
+            return self.schema_cache

-            schema_items = []
-
-            table_names_orig = schema.get(self.input_table_names_original_key, [])
-            table_names = schema.get(self.input_table_names_statement_key, [])
-
-            for table_idx, (table_orig, table_name) in enumerate(zip(table_names_orig, table_names)):
-                table_comment = table_name if table_name != table_orig else ""
-
-                columns = []
-                column_comments = []
-
-                for col_info_orig, col_info in zip(
-                    schema.get(self.input_column_names_original_key, []),
-                    schema.get(self.input_column_names_statement_key, [])
-                ):
-                    if col_info_orig[0] == table_idx:
-                        col_orig = col_info_orig[1]
-                        col_name = col_info[1]
+        try:
+            with open(self.table_info_file, 'r', encoding='utf-8') as f:
+                for line in f:
+                    if line.strip():
+                        schema_info = json.loads(line.strip())
+                        db_id = schema_info['db_id']

-                        col_comment = col_name if col_name != col_orig else ""
+                        schema = {}
+                        table_names_original = schema_info['table_names_original']
+                        column_names_original = schema_info['column_names_original']

-                        columns.append(col_orig)
-                        column_comments.append(col_comment)
-
-                schema_items.append({
-                    "table_name": table_orig,
-                    "table_comment": table_comment,
-                    "column_names": columns,
-                    "column_comments": column_comments
-                })
+                        for table_idx, table_name in enumerate(table_names_original):
+                            schema[table_name.lower()] = []
+
+                        for col_info in column_names_original:
+                            table_idx, col_name = col_info
+                            if table_idx >= 0:
+                                table_name = table_names_original[table_idx].lower()
+                                schema[table_name].append(col_name.lower())
+
+                        self.schema_cache[db_id] = schema
+
+        except Exception as e:
+            self.logger.error(f"Error loading schema info from {self.table_info_file}: {e}")

-            if self.selection_mode == "train" or self.input_sql_key != "":
-                processed_data.append({
-                    "text": row[self.input_question_key],
-                    "sql": row[self.input_sql_key],
-                    "schema": {
-                        "schema_items": schema_items
-                    }
-                })
-            else:
-                processed_data.append({
-                    "text": row[self.input_question_key],
-                    "schema": {
-                        "schema_items": schema_items
-                    }
-                })
-
-        return processed_data
+        return self.schema_cache
+
+    def get_schema_for_db(self, db_id):
+        schema_cache = self.load_schema_info()
+        return schema_cache.get(db_id, {})

-    def find_used_tables_and_columns(self, dataset):
-        for data in dataset:
-            sql = data["sql"].lower()
-            data["table_labels"] = []
-            data["column_labels"] = []
+    def normalize_sql_column_references(self, sql: str, schema: dict, alias_map: dict) -> str:
+        col_to_table = {}
+        all_tables = []
+
+        for table, cols in schema.items():
+            all_tables.append(table)
+            for col in cols:
+                if col not in col_to_table:
+                    col_to_table[col] = []
+                col_to_table[col].append(table)
+
+        col_fix_map = {}
+        for col, tables in col_to_table.items():
+            if len(tables) == 1:
+                table = tables[0]
+                alias = None
+                for a, t in alias_map.items():
+                    if t == table:
+                        alias = a
+                        break
+                if alias:
+                    col_fix_map[col] = f'"{alias}"."{col}"'
+
+        alias_pattern1 = re.compile(r'\bAS\s+"?([a-zA-Z_][\w]*)"?', re.IGNORECASE)
+        alias_names = set(m.group(1) for m in alias_pattern1.finditer(sql))
+
+        alias_pattern2 = re.compile(r'\bAS\s+(?:"(?P<dq>[^"]+)"|`(?P<bq>[^`]+)`)', re.IGNORECASE)
+        for m in alias_pattern2.finditer(sql):
+            alias = m.group('dq') or m.group('bq')
+ alias_names.add(alias) + + def replace_col(m): + col = m.group(0).strip('"') + bef = m.string[max(0, m.start()-10):m.start()] - for table_info in data["schema"]["schema_items"]: - table_name = table_info["table_name"] - data["table_labels"].append(1 if table_name.lower() in sql else 0) - data["column_labels"].append([1 if column_name.lower() in sql else 0 \ - for column_name in table_info["column_names"]]) - return dataset - - def filter_func(self, dataset, dataset_type, sic, num_top_k_tables = 5, num_top_k_columns = 5): - for data in tqdm(dataset, desc = "filtering schema items for the dataset"): - filtered_schema = dict() - filtered_schema["schema_items"] = [] + if ('.' in bef or col in all_tables or col in alias_names or + col in SQLITE_RESERVED_KEYWORDS): + return m.group(0) + + if ((m.group(0).startswith('"') and not m.group(0).endswith('"')) or + (not m.group(0).startswith('"') and m.group(0).endswith('"'))): + return m.group(0) + + return col_fix_map.get(col, m.group(0)) - table_names = [table["table_name"] for table in data["schema"]["schema_items"]] - table_comments = [table["table_comment"] for table in data["schema"]["schema_items"]] - column_names = [table["column_names"] for table in data["schema"]["schema_items"]] - column_comments = [table["column_comments"] for table in data["schema"]["schema_items"]] + if col_fix_map: + pattern = re.compile( + r'(? Date: Sat, 28 Jun 2025 11:58:44 +0800 Subject: [PATCH 2/3] update the setting of schemalinking in test_text2sql.py --- test/test_text2sql.py | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/test/test_text2sql.py b/test/test_text2sql.py index 93a650b..b491014 100644 --- a/test/test_text2sql.py +++ b/test/test_text2sql.py @@ -33,13 +33,9 @@ def __init__(self,llm_serving=None): self.sql_difficulty_classifier_step2 = SQLDifficultyClassifier() - # self.schema_linking_step3 = SchemaLinking( - # table_info_file=table_info_file, - # model_path="", # download the model from https://pan.quark.cn/s/418c417127ae or https://drive.google.com/file/d/1xzNvv5h-ZjhjOOZ-ePv1xg_n3YbUNLWi/view?usp=sharing - # selection_mode="eval", - # num_top_k_tables=5, - # num_top_k_columns=5 - # ) + self.schema_linking_step3 = SchemaLinking( + table_info_file=table_info_file + ) self.database_schema_extractor_step4 = DatabaseSchemaExtractor( table_info_file=table_info_file, @@ -93,17 +89,12 @@ def forward(self): output_difficulty_key="sql_component_difficulty" ) - # self.schema_linking_step3.run( - # storage=self.storage.step(), - # input_sql_key=input_sql_key, - # input_dbid_key=input_dbid_key, - # input_question_key=input_question_key, - # input_table_names_original_key="table_names_original", - # input_table_names_statement_key="table_names", - # input_column_names_original_key="column_names_original", - # input_column_names_statement_key="column_names", - # output_schema_key="selected_schema" - # ) + self.schema_linking_step3.run( + storage=self.storage.step(), + input_sql_key=input_sql_key, + input_dbid_key=input_dbid_key, + output_used_schema_key="selected_schema" + ) self.database_schema_extractor_step4.run( storage=self.storage.step(), From 9ba0b6a7b529b98ca4ca78b74ad479744d66e77f Mon Sep 17 00:00:00 2001 From: TechNomad-ds <18136835112@163.com> Date: Sun, 29 Jun 2025 12:19:38 +0800 Subject: [PATCH 3/3] delete the demo database path --- test/test_text2sql.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/test_text2sql.py b/test/test_text2sql.py index b491014..9c246a7 100644 --- 
a/test/test_text2sql.py
+++ b/test/test_text2sql.py
@@ -21,7 +21,9 @@ def __init__(self,llm_serving=None):
         else:
             api_llm_serving = llm_serving

-        db_root_path = "../dataflow/example/Text2SQLPipeline/dev_databases"
+        # Please download the demo database from the following URL:
+        # https://huggingface.co/datasets/Open-Dataflow/dataflow-Text2SQL-database-example
+        db_root_path = ""
         table_info_file = "../dataflow/example/Text2SQLPipeline/dev_tables.jsonl"

         self.sql_filter_step1 = SQLFilter(