diff --git a/docs/en/notes/guide/domain_specific_operators/text2sql_operators.md b/docs/en/notes/guide/domain_specific_operators/text2sql_operators.md index 9e90fd824..a28146bfb 100644 --- a/docs/en/notes/guide/domain_specific_operators/text2sql_operators.md +++ b/docs/en/notes/guide/domain_specific_operators/text2sql_operators.md @@ -8,15 +8,15 @@ permalink: /en/guide/Text2SQL_operators/ ## Overview -Text-to-SQL operators are a specialized set of components designed for data processing and quality enhancement in Text-to-SQL tasks, aiming to: -- Clean and augment existing Text-to-SQL datasets -- Generate high-quality question-answer pairs for each sample, including training prompts and chain-of-thought (CoT) reasoning processes -- Provide multi-dimensional data quality assessment and difficulty grading +Text-to-SQL operators are a specialized collection of operators designed for data processing and quality enhancement in Text-to-SQL tasks, aiming to: +- Clean and augment existing Text-to-SQL datasets. +- Generate high-quality question-answer pairs containing training prompts and chain-of-thought reasoning for each sample. +- Provide multi-dimensional data quality assessment and difficulty grading. -Open-source operator varieties are severely limited. To achieve superior data processing quality and fill the gap in publicly available data synthesis and processing methods, we have meticulously designed and **developed in-house** a new suite of operators. Their labels carry the following meanings: +The variety of open-source operators is quite limited. To achieve better data processing quality and fill the gaps in open-source data synthesis and processing methods, we have meticulously designed and **independently developed** a new set of operators. Their marker meanings are as follows: -- 🚀 **Innovative Development**: Core algorithms are originally developed, either filling existing algorithmic gaps or further enhancing performance beyond current bottlenecks. -- ✨ **Open-Source First**: This operator is integrated into mainstream community frameworks for the first time, enabling broader developer adoption and open sharing. +- 🚀 **Independent Innovation**: Core algorithms are originally developed, filling gaps in existing algorithms or further improving performance, breaking through current bottlenecks. +- ✨ **Open-Source Debut**: This operator is integrated into the mainstream community framework for the first time, facilitating use by more developers and enabling open-source sharing. ## Data Generation Operators @@ -24,7 +24,7 @@ Open-source operator varieties are severely limited. To achieve superior data pr Name - Type + Applicable Type Description Official Repository or Paper @@ -33,33 +33,39 @@ Open-source operator varieties are severely limited. To achieve superior data pr SQLGenerator Data Generation - Generates diverse SQL statements based on database schemas + Generates diverse SQL statements based on database schema. OmniSQL SQLVariationGenerator🚀 Data Augmentation - Generates SQL variants based on SQL statements and database schemas + Generates SQL variants based on SQL statements and database schema. - Text2SQLQuestionGenerator Question Generation - Generates corresponding natural language questions from SQL statements + Generates corresponding natural language questions based on SQL statements. OmniSQL Text2SQLPromptGenerator✨ Prompt Generation - Constructs training prompts containing schema and question information + Constructs training prompts containing schema and questions. 
- Text2SQLCoTGenerator - Chain-of-Thought Generation - Generates step-by-step reasoning chains for SQL derivation + Reasoning Chain Generation + Generates step-by-step chain-of-thought reasoning for SQL. OmniSQL + + Text2SQLCoTVotingGenerator✨ + Reasoning Chain Selection + Performs execution consistency voting on candidate reasoning processes to select the final CoT. + - + @@ -69,7 +75,7 @@ Open-source operator varieties are severely limited. To achieve superior data pr Name - Type + Applicable Type Description Official Repository or Paper @@ -78,13 +84,13 @@ Open-source operator varieties are severely limited. To achieve superior data pr SQLComponentClassifier Difficulty Assessment - Grades difficulty based on SQL syntax complexity + Performs difficulty grading based on SQL syntax complexity. Spider SQLExecutionClassifier🚀 Execution Difficulty Assessment - Grades difficulty based on model execution success rate + Performs difficulty grading based on model execution success rate. - @@ -96,7 +102,7 @@ Open-source operator varieties are severely limited. To achieve superior data pr Name - Type + Applicable Type Description Official Repository or Paper @@ -105,33 +111,39 @@ Open-source operator varieties are severely limited. To achieve superior data pr SQLExecutionFilter✨ Data Cleaning - Filters out SQL statements that cannot be executed successfully + Filters SQL statements that cannot be executed normally. - - SQLConsistencyFilter✨ + SQLExecutabilityFilter✨ Data Cleaning - Verifies semantic consistency between SQL queries and their corresponding natural language questions + Uses query plans to filter inexecutable SQL statements. + - + + + Text2SQLCorrespondenceFilter✨ + Data Cleaning + Verifies semantic consistency between SQL and problem description. - -## Operator Interface Usage Guide +## Operator Interface Usage Instructions -Specifically, for operators requiring designated storage paths or model invocations, we provide encapsulated **Model Interfaces**, **Storage Object Interfaces**, and **Database Management Interfaces**. These interfaces allow pre-definition of required configurations. +Specifically, for operators requiring specified storage paths or model calls, we provide encapsulated **Model Interfaces**, **Storage Object Interfaces**, and **Database Management Interfaces**. These interfaces allow for pre-defining the required configurations. 
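As a quick orientation, the hedged sketch below defines one object of each interface type in a single place; every value shown (URL, model name, paths) is a placeholder, and each option is explained in the subsections that follow. Operators then receive these objects through their constructors and `run()` calls, as shown in the detailed operator descriptions later in this document.

```python
from dataflow.serving import APILLMServing_request
from dataflow.utils.storage import FileStorage
from dataflow.utils.text2sql.database_manager import DatabaseManager

# Model interface: one LLM endpoint shared by the generation and filtering operators.
llm_serving = APILLMServing_request(
    api_url="your_api_url",      # placeholder API service URL
    model_name="model_name",     # placeholder model name
    max_workers=5                # maximum concurrency
)

# Storage interface: where each pipeline step caches its intermediate results.
storage = FileStorage(
    first_entry_file_name="your_file_path",  # placeholder initial file path
    cache_path="./cache",
    file_name_prefix="dataflow_cache_step",
    cache_type="jsonl"
)

# Database management interface: how operators read schemas and execute SQL.
database_manager = DatabaseManager(
    db_type="sqlite",                                       # or "mysql"
    config={"root_path": "/path/to/your/database/folder"}   # placeholder path
)
```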
### Model Interface Configuration -You can pre-define API parameters for operators using the following methods, supporting both generative and embedding models: +You can pre-define model API parameters for operators in the following way, including generative models and embedding models: ```python from dataflow.serving import APILLMServing_request api_llm_serving = APILLMServing_request( - api_url="your_api_url", # API service endpoint + api_url="your_api_url", # API service URL model_name="model_name", # Model name - max_workers=5 # Maximum concurrent workers + max_workers=5 # Maximum concurrency ) embedding_serving = APILLMServing_request( @@ -143,7 +155,7 @@ embedding_serving = APILLMServing_request( ### Storage Interface Configuration -You can pre-define storage parameters for operators as follows: +You can pre-define storage parameters for operators in the following way: ```python from dataflow.utils.storage import FileStorage @@ -151,14 +163,14 @@ from dataflow.utils.storage import FileStorage storage = FileStorage( first_entry_file_name="your_file_path", # Initial file path cache_path="./cache", # Cache directory - file_name_prefix="dataflow_cache_step", # File name prefix + file_name_prefix="dataflow_cache_step", # Filename prefix cache_type="jsonl", # Cache file type ) ``` ### Database Management Interface Configuration -Since database schema information is required, you can pre-define database management as follows. Within operators, interactions with the database manager enable reading and managing database information: +Since database schema information is required, you can pre-define the database management as follows. In the operators, database information is read and managed by interacting with the database manager: ```python from dataflow.utils.text2sql.database_manager import DatabaseManager @@ -167,22 +179,22 @@ database_manager = DatabaseManager( db_type="your_db_type", # Currently supports SQLite and MySQL config={ "your_db_config_key": "your_db_config_value" - } + } ) ``` -Note that SQLite and MySQL databases require specific configuration formats: +Note: For SQLite and MySQL databases, configuration should be done as follows: ```python -# SQLite Full Example +# Complete SQLite example database_manager = DatabaseManager( db_type="sqlite", config={ - "root_path": "/path/to/your/database/folder" # Directory containing SQLite files + "root_path": "/path/to/your/database/folder" # Directory path containing SQLite files } ) -# MySQL Full Example +# Complete MySQL example database_manager = DatabaseManager( db_type="mysql", config={ @@ -197,7 +209,7 @@ database_manager = DatabaseManager( ### Prompt Template Configuration -Operators support predefined prompt templates, which can be imported and used as follows: +Operators support using predefined prompt templates. You can import and use them as follows: ```python from dataflow.prompts.text2sql import ( @@ -205,14 +217,14 @@ from dataflow.prompts.text2sql import ( SelectSQLGeneratorPrompt, Text2SQLQuestionGeneratorPrompt, Text2SQLPromptGeneratorPrompt, - SQLConsistencyFilterPrompt, + Text2SQLCorrespondenceFilterPrompt, SQLVariationGeneratorPrompt ) ``` -The `llm_serving`, `storage`, `database_manager`, and prompt templates referenced later are the interface objects defined above. Complete usage examples can be found in actual pipeline code. +The `llm_serving`, `storage`, `database_manager`, and prompt template objects used later refer to the interface objects defined here. 
For complete usage examples, please refer to the actual pipeline code. -For parameter passing: The operator constructor primarily accepts configuration-related parameters, allowing one-time setup for multiple uses; while the `X.run()` function accepts I/O-related `key` parameters. See detailed operator descriptions below for specifics. +Regarding parameters: The constructor of the operator object mainly passes information related to operator configuration, which can be set once and used multiple times. The `X.run()` function passes `key` information related to I/O. Details can be found in the operator description examples below. ## Detailed Operator Descriptions @@ -220,27 +232,29 @@ For parameter passing: The operator constructor primarily accepts configuration- #### 1. SQLGenerator -**Function Description:** Generates diverse SQL statements based on database schemas. -- Generates queries covering various SQL syntaxes and difficulty levels -- Supports complex queries such as JOINs, subqueries, and aggregate functions +**Function Description:** Generates diverse SQL statements based on database schema. +- Generates query statements covering various SQL syntax and difficulty levels. +- Supports complex queries such as JOINs, subqueries, aggregate functions, etc. **Input Parameters:** - `__init__()` - - `llm_serving`: LLM service interface for SQL generation - - `database_manager`: Database manager for accessing schema information - - `generate_num`: Number of SQL statements to generate per database - - `prompt_template`: Prompt template for SQL generation + - `llm_serving`: LLM service interface for SQL generation. + - `database_manager`: Database manager for accessing database schema. + - `generate_num`: Number of SQL statements to generate per database. + - `prompt_template`: Prompt template for SQL generation. - `run()` - - `output_sql_key`: Output field name for SQL statements, default "SQL" - - `output_db_id_key`: Output field name for database ID, default "db_id" + - `output_sql_key`: Output SQL statement field name, default "SQL". + - `output_db_id_key`: Output database ID field name, default "db_id". + - `output_sql_complexity_key`: Output SQL complexity field name, default "sql_complexity_type". **Key Features:** -- Intelligent schema analysis and SQL template generation -- Supports multiple database types (SQLite, MySQL) -- Automatically handles table relationships and foreign key constraints -- Generates SQL across varying difficulty levels + +- Intelligent schema analysis and SQL template generation. +- Supports multiple database types (SQLite, MySQL). +- Automatically handles table relationships and foreign key constraints. +- Generates SQL covering different difficulty levels. **Usage Example:** @@ -248,7 +262,6 @@ For parameter passing: The operator constructor primarily accepts configuration- from dataflow.prompts.text2sql import SelectSQLGeneratorPrompt sql_generator = SQLGenerator( - llm_serving=llm_serving, database_manager=database_manager, generate_num=50, prompt_template=SelectSQLGeneratorPrompt() @@ -256,33 +269,36 @@ sql_generator = SQLGenerator( sql_generator.run( storage=storage.step(), output_sql_key="SQL", - output_db_id_key="db_id" + output_db_id_key="db_id", + output_sql_complexity_key="sql_complexity_type" ) ``` #### 2. SQLVariationGenerator🚀 -**Function Description:** Generates SQL variants based on SQL statements and database schemas. -- Increases syntactic diversity -- Supports alias substitution, subquery transformation, JOIN rewriting, etc. 
-- Effectively expands training data diversity +**Function Description:** Generates SQL statement variants based on SQL statements and database schema. +- Increases syntactic diversity. +- Supports alias replacement, subquery transformation, JOIN rewriting, etc. +- Effectively expands the diversity of training data. **Input Parameters:** - `__init__()` - - `llm_serving`: LLM service interface for variant generation - - `database_manager`: Database manager for validating variant correctness - - `num_variations`: Number of variants to generate per SQL, default 5 - - `prompt_template`: Prompt template for variant generation + - `llm_serving`: LLM service interface for SQL variant generation. + - `database_manager`: Database manager for validating variant correctness. + - `num_variations`: Number of variants to generate per SQL, default 5. + - `prompt_template`: Prompt template for SQL variant generation. - `run()` - - `input_sql_key`: Input field name for SQL statements, default "SQL" - - `input_db_id_key`: Input field name for database ID, default "db_id" + - `input_sql_key`: SQL statement field name, default "SQL". + - `input_db_id_key`: Database ID field name, default "db_id". + - `output_sql_variation_type_key`: Output SQL variant type field name, default "sql_variation_type". **Key Features:** -- Intelligent SQL variant generation -- Covers multiple variation directions to ensure diversity -- Supports multiple expression styles for complex queries + +- Intelligent SQL variant generation. +- Covers multiple variant directions to ensure SQL statement diversity. +- Supports various expressions for complex queries. **Usage Example:** @@ -290,7 +306,6 @@ sql_generator.run( from dataflow.prompts.text2sql import SQLVariationGeneratorPrompt sql_variation_generator = SQLVariationGenerator( - llm_serving=llm_serving, database_manager=database_manager, num_variations=5, prompt_template=SQLVariationGeneratorPrompt() @@ -298,37 +313,41 @@ sql_variation_generator = SQLVariationGenerator( sql_variation_generator.run( storage=storage.step(), input_sql_key="SQL", - input_db_id_key="db_id" + input_db_id_key="db_id", + output_sql_variation_type_key="sql_variation_type" ) ``` #### 3. Text2SQLQuestionGenerator -**Function Description:** Generates natural language questions corresponding to SQL statements. -- Analyzes SQL semantics to generate reasonable questions -- Uses embedding techniques to select optimal question candidates -- Ensures consistency between questions and SQL intent -- Supports multiple question expression styles +**Function Description:** Generates corresponding natural language questions based on SQL statements. +- Analyzes SQL semantics to generate reasonable natural language questions. +- Uses embedding technology to select the optimal question candidate. +- Ensures consistency between the question and the SQL query intent. +- Supports multiple question expression styles. **Input Parameters:** - `__init__()` - - `llm_serving`: LLM service interface for question generation - - `embedding_serving`: Embedding model interface for candidate selection - - `database_manager`: Database manager for schema information - - `question_candidates_num`: Number of candidate questions, default 5 - - `prompt_template`: Prompt template for question generation + - `llm_serving`: LLM service interface for question generation. + - `embedding_serving`: Embedding model interface for question selection. + - `database_manager`: Database manager for obtaining schema information. 
+ - `question_candidates_num`: Number of question candidates, default 5. + - `prompt_template`: Prompt template for question generation. - `run()` - - `input_sql_key`: Input field name for SQL statements, default "SQL" - - `input_db_id_key`: Input field name for database ID, default "db_id" - - `output_question_key`: Output field name for generated questions, default "question" + - `input_sql_key`: SQL statement field name, default "SQL". + - `input_db_id_key`: Database ID field name, default "db_id". + - `output_question_key`: Output question field name, default "question". + - `output_evidence_key`: Output evidence field name, default "evidence". **Key Features:** -- Semantics-aware intelligent question generation -- Multi-candidate generation with optimal selection -- Contextual understanding leveraging database schema -- Ensures naturalness and accuracy of questions + +- Intelligent question generation based on SQL semantics. +- Multi-candidate question generation and optimal selection. +- Contextual understanding combining database schema. +- Ensures naturalness and accuracy of questions. +- Automatically supplements the `question_type` field. **Usage Example:** @@ -336,7 +355,6 @@ sql_variation_generator.run( from dataflow.prompts.text2sql import Text2SQLQuestionGeneratorPrompt text2sql_question_generator = Text2SQLQuestionGenerator( - llm_serving=llm_serving, embedding_serving=embedding_serving, database_manager=database_manager, question_candidates_num=5, @@ -346,34 +364,37 @@ text2sql_question_generator.run( storage=storage.step(), input_sql_key="SQL", input_db_id_key="db_id", - output_question_key="question" + output_question_key="question", + output_evidence_key="evidence" ) ``` #### 4. Text2SQLPromptGenerator✨ -**Function Description:** Constructs training prompts containing schema and question information. -- Formats database schema information -- Generates standardized prompts combining schema and question -- Supports multiple schema formats (DDL, formatted schema, etc.) -- Configurable option to include example data +**Function Description:** Constructs training prompts containing schema and questions. +- Formats database schema information. +- Generates standardized training prompts combining questions. +- Supports multiple schema formats (DDL, formatted schema, etc.). +- Configurable to include example data. **Input Parameters:** - `__init__()` - - `database_manager`: Database manager for schema information - - `prompt_template`: Prompt template must include placeholders {schema} and {question} + - `database_manager`: Database manager for obtaining schema information. + - `prompt_template`: Prompt template, must contain `{schema}` and `{question}` placeholders. - `run()` - - `input_question_key`: Input field name for questions, default "question" - - `input_db_id_key`: Input field name for database ID, default "db_id" - - `output_prompt_key`: Output field name for generated prompts, default "prompt" + - `input_question_key`: Question field name, default "question". + - `input_db_id_key`: Database ID field name, default "db_id". + - `input_evidence_key`: Evidence field name, default "evidence". + - `output_prompt_key`: Output prompt field name, default "prompt". **Key Features:** -- Flexible prompt template system -- Support for multiple schema formats -- Automatic schema formatting and optimization -- Supports schema with embedded example data + +- Flexible prompt template system. +- Support for multiple schema formats. +- Automatic schema formatting and optimization. 
+- Supports schemas containing example data. **Usage Example:** @@ -388,38 +409,40 @@ text2sql_prompt_generator.run( storage=storage.step(), input_question_key="question", input_db_id_key="db_id", + input_evidence_key="evidence", output_prompt_key="prompt" ) ``` #### 5. Text2SQLCoTGenerator -**Function Description:** Generates step-by-step reasoning chains for SQL derivation. -- Produces detailed reasoning steps based on questions and SQL -- Explains the logical process behind SQL construction -- Supports retry mechanisms and quality assurance -- Enhances model reasoning capability and interpretability +**Function Description:** Generates step-by-step chain-of-thought reasoning for SQL. +- Generates detailed reasoning steps based on questions and SQL. +- Explains the logical process of SQL construction. +- Generates multiple candidate reasoning processes (no validation). +- Improves model reasoning ability and explainability. **Input Parameters:** - `__init__()` - - `llm_serving`: LLM service interface for CoT generation - - `database_manager`: Database manager for schema information - - `max_retries`: Maximum retry attempts, default 3 - - `enable_retry`: Whether to enable retry mechanism, default True - - `prompt_template`: Prompt template for CoT generation + - `llm_serving`: LLM service interface for CoT generation. + - `database_manager`: Database manager for obtaining schema information. + - `sampling_num`: Number of candidate reasoning processes to generate, default 3. + - `prompt_template`: Prompt template for CoT generation. - `run()` - - `input_sql_key`: Input field name for SQL statements, default "SQL" - - `input_question_key`: Input field name for questions, default "question" - - `input_db_id_key`: Input field name for database ID, default "db_id" - - `output_cot_key`: Output field name for CoT reasoning, default "cot_reasoning" + - `input_sql_key`: SQL statement field name, default "SQL". + - `input_question_key`: Question field name, default "question". + - `input_db_id_key`: Database ID field name, default "db_id". + - `input_evidence_key`: Evidence field name, default "evidence". + - `output_cot_key`: Output CoT reasoning field name, default "cot_reasoning" (actual output column is `cot_responses`). **Key Features:** -- High-quality reasoning chain generation -- Automated error detection and retry mechanism -- Schema-aware contextual reasoning -- Supports stepwise decomposition of complex queries + +- High-quality reasoning chain generation. +- Multi-candidate reasoning process output (`cot_responses`). +- Contextual reasoning combining schema. +- Supports step-by-step decomposition of complex queries. **Usage Example:** @@ -429,8 +452,7 @@ from dataflow.prompts.text2sql import Text2SQLCotGeneratorPrompt text2sql_cot_generator = Text2SQLCoTGenerator( llm_serving=cot_generation_llm_serving, database_manager=database_manager, - max_retries=3, - enable_retry=True, + sampling_num=3, prompt_template=Text2SQLCotGeneratorPrompt() ) text2sql_cot_generator.run( @@ -438,6 +460,44 @@ text2sql_cot_generator.run( input_sql_key="SQL", input_question_key="question", input_db_id_key="db_id", + input_evidence_key="evidence", + output_cot_key="cot_reasoning" +) +``` + +#### 6. Text2SQLCoTVotingGenerator✨ + +**Function Description:** Performs execution consistency voting on candidate CoTs to select the final reasoning process. +- Extracts SQL from `cot_responses` and executes them. +- Votes based on execution result consistency. +- Outputs the final `cot_reasoning`. 
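To make the voting rule concrete, the following self-contained sketch illustrates the execution-consistency idea. It is not the operator's internal implementation: the SQL-extraction regex, the first-match tie handling, and the direct `sqlite3` connection are simplifying assumptions.

```python
import re
import sqlite3
from collections import Counter

def vote_cot(cot_responses: list[str], db_path: str) -> str | None:
    """Pick the CoT whose final SQL result agrees with the most other candidates."""
    signatures = []
    for cot in cot_responses:
        # Illustrative: treat the last ```sql ...``` block as the candidate's final SQL.
        blocks = re.findall(r"```sql\s*(.*?)```", cot, flags=re.S | re.I)
        if not blocks:
            signatures.append(None)              # no SQL found -> invalid candidate
            continue
        try:
            with sqlite3.connect(db_path) as conn:
                rows = conn.execute(blocks[-1]).fetchall()
            signatures.append(frozenset(rows))   # order-insensitive result signature
        except sqlite3.Error:
            signatures.append(None)              # execution failed -> invalid candidate

    valid = [s for s in signatures if s is not None]
    if not valid:
        return None                              # every candidate failed
    winner, _ = Counter(valid).most_common(1)[0]
    # Return the first candidate whose execution result matches the winning signature.
    return next(c for c, s in zip(cot_responses, signatures) if s == winner)
```

Grouping candidates by result sets rather than by SQL text lets syntactically different but equivalent queries vote for the same answer.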
+ +**Input Parameters:** + +- `__init__()` + - `database_manager`: Database manager for executing SQL and comparing results. + +- `run()` + - `input_cot_responses_key`: Candidate CoT field name, default "cot_responses". + - `input_db_id_key`: Database ID field name, default "db_id". + - `output_cot_key`: Output final CoT field name, default "cot_reasoning". + +**Key Features:** + +- Reliable voting based on execution consistency. +- Automatically handles invalid candidates and ties. +- Generates the final usable reasoning process. + +**Usage Example:** + +```python +text2sql_cot_voter = Text2SQLCoTVotingGenerator( + database_manager=database_manager +) +text2sql_cot_voter.run( + storage=storage.step(), + input_cot_responses_key="cot_responses", + input_db_id_key="db_id", output_cot_key="cot_reasoning" ) ``` @@ -446,27 +506,28 @@ text2sql_cot_generator.run( #### 1. SQLComponentClassifier -**Function Description:** Grades difficulty based on SQL syntax complexity. -- Analyzes syntactic components of SQL statements -- Scores based on JOIN count, subquery depth, aggregate functions, etc. -- Supports custom difficulty thresholds and labels -- Provides a standardized difficulty classification system +**Function Description:** Performs difficulty grading based on SQL syntax complexity. +- Analyzes the complexity of SQL statement syntax components. +- Scores based on number of JOINs, subquery depth, aggregate functions, etc. +- Supports custom difficulty thresholds and labels. +- Provides a standardized difficulty classification system. **Input Parameters:** - `__init__()` - - `difficulty_thresholds`: List of difficulty thresholds, default [2, 4, 6] - - `difficulty_labels`: List of difficulty labels, default ['easy', 'medium', 'hard', 'extra'] + - `difficulty_thresholds`: List of difficulty thresholds, default [2, 4, 6]. + - `difficulty_labels`: List of difficulty labels, default ['easy', 'medium', 'hard', 'extra']. - `run()` - - `input_sql_key`: Input field name for SQL statements, default "SQL" - - `output_difficulty_key`: Output field name for difficulty label, default "sql_component_difficulty" + - `input_sql_key`: SQL statement field name, default "SQL". + - `output_difficulty_key`: Output difficulty label field name, default "sql_component_difficulty". **Key Features:** -- Complexity analysis based on SQL syntax structure -- Configurable thresholds and labels -- Fast batch processing capability -- Standardized evaluation framework + +- Complexity analysis based on SQL syntactic structure. +- Configurable difficulty thresholds and labels. +- Fast batch processing capability. +- Standardized difficulty assessment system. **Usage Example:** @@ -484,32 +545,33 @@ sql_component_classifier.run( #### 2. SQLExecutionClassifier🚀 -**Function Description:** Grades difficulty based on model execution success rate. -- Uses LLM to repeatedly attempt SQL generation for difficulty assessment -- Dynamically adjusts difficulty level based on model success rate -- Provides difficulty evaluation more aligned with real-world applications -- Supports customizable difficulty configurations and generation counts +**Function Description:** Performs difficulty grading based on model execution success rate. +- Uses LLM to attempt SQL generation multiple times to assess difficulty. +- Dynamically adjusts difficulty levels based on model success rate. +- Provides difficulty assessment closer to real-world applications. +- Supports custom difficulty configurations and generation counts. 
**Input Parameters:** - `__init__()` - - `llm_serving`: LLM service interface for test generation - - `database_manager`: Database manager for SQL execution validation - - `num_generations`: Number of generation attempts, default 10 - - `difficulty_thresholds`: Difficulty thresholds list, default [2, 5, 9] - - `difficulty_labels`: Difficulty labels list, default ['extra', 'hard', 'medium', 'easy'] + - `llm_serving`: LLM service interface for SQL generation testing. + - `database_manager`: Database manager for SQL execution verification. + - `num_generations`: Number of test generations, default 10. + - `difficulty_thresholds`: List of difficulty thresholds, default [2, 5, 9]. + - `difficulty_labels`: List of difficulty labels, default ['extra', 'hard', 'medium', 'easy']. - `run()` - - `input_sql_key`: Input field name for SQL statements, default "SQL" - - `input_db_id_key`: Input field name for database ID, default "db_id" - - `input_prompt_key`: Input field name for prompts, default "prompt" - - `output_difficulty_key`: Output field name for difficulty label, default "sql_execution_difficulty" + - `input_sql_key`: SQL statement field name, default "SQL". + - `input_db_id_key`: Database ID field name, default "db_id". + - `input_prompt_key`: Prompt field name, default "prompt". + - `output_difficulty_key`: Output difficulty label field name, default "sql_execution_difficulty". **Key Features:** -- Difficulty evaluation based on actual model performance -- Dynamic adjustment mechanism -- Statistical analysis from multiple generations -- Difficulty grading more representative of real-world scenarios + +- Difficulty assessment based on actual model performance. +- Dynamic difficulty adjustment mechanism. +- Statistical analysis of multiple generations. +- Difficulty grading more aligned with real-world scenarios. **Usage Example:** @@ -534,25 +596,26 @@ sql_execution_classifier.run( #### 1. SQLExecutionFilter✨ -**Function Description:** Validates SQL executability and syntactic correctness. -- Executes SQL statements in a real database environment -- Detects syntax errors, runtime errors, and logical inconsistencies -- Filters out non-executable SQL statements -- Ensures SQL validity and usability within the dataset +**Function Description:** Verifies the executability and syntactic correctness of SQL statements. +- Executes SQL statements in a real database environment. +- Detects syntax errors, runtime errors, and logical errors. +- Filters SQL statements that cannot be executed normally. +- Ensures the validity and usability of SQL in the dataset. **Input Parameters:** - `__init__()` - - `database_manager`: Database manager for SQL execution and validation + - `database_manager`: Database manager for SQL execution and verification. - `run()` - - `input_sql_key`: Input field name for SQL statements, default "SQL" - - `input_db_id_key`: Input field name for database ID, default "db_id" + - `input_sql_key`: SQL statement field name, default "SQL". + - `input_db_id_key`: Database ID field name, default "db_id". **Key Features:** -- Real-time SQL execution validation -- Automatic filtering of failed executions -- Efficient batch processing + +- Real-time SQL execution verification. +- Automatic filtering of failed SQL statements. +- Efficient batch processing capability. **Usage Example:** @@ -567,46 +630,82 @@ sql_execution_filter.run( ) ``` -#### 2. SQLConsistencyFilter✨ +#### 2. 
SQLExecutabilityFilter✨ -**Function Description:** Verifies semantic consistency between SQL and natural language questions. -- Uses LLM to determine whether SQL results answer the posed question -- Checks alignment between question semantics and SQL logic -- Filters semantically inconsistent question-SQL pairs -- Enhances dataset quality and reliability +**Function Description:** Filters inexecutable SQL using query plans. +- Generates query plans via database EXPLAIN. +- Judges executability without executing the SQL. +- Filters SQL statements that are inexecutable or invalid. **Input Parameters:** - `__init__()` - - `llm_serving`: LLM service interface for consistency judgment - - `database_manager`: Database manager for SQL execution - - `prompt_template`: Prompt template for consistency checking + - `database_manager`: Database manager for generating query plans. - `run()` - - `input_sql_key`: Input field name for SQL statements, default "SQL" - - `input_db_id_key`: Input field name for database ID, default "db_id" - - `input_question_key`: Input field name for questions, default "question" + - `input_sql_key`: SQL statement field name, default "SQL". + - `input_db_id_key`: Database ID field name, default "db_id". **Key Features:** -- Intelligent semantic consistency checking -- Combines SQL execution results with question semantics -- Automatically filters mismatched pairs -- Supports consistency verification for complex queries + +- Fast filtering without SQL execution. +- Lower resource consumption and higher throughput. +- Can be combined with execution filters. **Usage Example:** ```python -from dataflow.prompts.text2sql import SQLConsistencyFilterPrompt +sql_executability_filter = SQLExecutabilityFilter( + database_manager=database_manager +) +sql_executability_filter.run( + storage=storage.step(), + input_sql_key="SQL", + input_db_id_key="db_id" +) +``` -sql_consistency_filter = SQLConsistencyFilter( +#### 3. Text2SQLCorrespondenceFilter✨ + +**Function Description:** Verifies the semantic consistency between SQL and problem description. +- Uses LLM to judge whether the SQL answers the question. +- Checks the match between the question and SQL logic. +- Filters semantically inconsistent data pairs. + +**Input Parameters:** + +- `__init__()` + - `llm_serving`: LLM service interface for consistency judgment. + - `database_manager`: Database manager for schema reading. + - `prompt_template`: Prompt template for consistency checking. + +- `run()` + - `input_sql_key`: SQL statement field name, default "SQL". + - `input_db_id_key`: Database ID field name, default "db_id". + - `input_question_key`: Question field name, default "question". + - `input_evidence_key`: Evidence field name, default "evidence". + +**Key Features:** + +- Intelligent semantic consistency checking. +- Consistency judgment incorporating schema. +- Automatic filtering of mismatched data pairs. 
+ +**Usage Example:** + +```python +from dataflow.prompts.text2sql import Text2SQLCorrespondenceFilterPrompt + +text2sql_correspondence_filter = Text2SQLCorrespondenceFilter( llm_serving=llm_serving, database_manager=database_manager, - prompt_template=SQLConsistencyFilterPrompt() + prompt_template=Text2SQLCorrespondenceFilterPrompt() ) -sql_consistency_filter.run( +text2sql_correspondence_filter.run( storage=storage.step(), input_sql_key="SQL", input_db_id_key="db_id", - input_question_key="question" + input_question_key="question", + input_evidence_key="evidence" ) ``` \ No newline at end of file diff --git a/docs/en/notes/guide/pipelines/Text2SqlPipeline.md b/docs/en/notes/guide/pipelines/Text2SqlPipeline.md index 96dce06eb..4a0737c46 100644 --- a/docs/en/notes/guide/pipelines/Text2SqlPipeline.md +++ b/docs/en/notes/guide/pipelines/Text2SqlPipeline.md @@ -1,51 +1,46 @@ --- title: Text-to-SQL Data Synthesis Pipeline icon: material-symbols-light:checkbook-outline-rounded -createTime: 2025/06/17 02:00:31 -permalink: /en/guide/text2sqlpipeline/ ---- - ---- -title: Text-to-SQL Data Synthesis Pipeline -icon: material-symbols-light:checkbook-outline-rounded -createTime: 2025/06/17 02:00:31 -permalink: /zh/guide/text2sqlpipeline/ +createTime: 2025/06/17 02:00:31 +permalink: /en/guide/text2sqlpipeline/ --- # Text-to-SQL Data Synthesis Pipeline ## 1. Overview -The core objective of the **Text-to-SQL Data Synthesis Pipeline** is to generate high-quality Q&A data containing training prompts and chain-of-thought for each sample by cleaning and augmenting existing Text-to-SQL data. This pipeline supports one-click end-to-end processing from raw data to final training data and currently offers the following two data generation modes: +The core objective of the **Text-to-SQL Data Synthesis Pipeline** is to generate high-quality QA data containing training prompts and chain-of-thought for each sample by cleaning and augmenting existing Text-to-SQL data. This pipeline supports one-click processing from raw data to final training data, currently offering the following two data generation modes: ### Supported Application Scenarios - **Data Refinement Mode** - - Filters, augments, and enhances existing data to generate high-quality training data - - Input requirements: Must include the three essential elements: database ID, natural language question, and standard SQL answer + - Filters, expands, and enhances existing data to generate high-quality training data + - Input requirements: Database files, SQL data samples (each containing a database ID and SQL query) - **Data Synthesis Mode** - - Directly generates training data from databases - - Characteristics: No existing data samples required, supports zero-shot startup + - Directly generates training data based on databases + - Input requirements: Database files, no existing SQL data samples required -### Processing Flow +### Processing Pipeline 1. **Data Filtering** - - Execution Filtering: Removes invalid SQL and non-executable SQL statements - - Consistency Filtering: Ensures consistency between the question, SQL, and database schema + - Execution Filtering: Executes SQL statements to eliminate invalid and non-executable SQL + - Executability Filtering: Uses the database to generate query plans (EXPLAIN) for SQL, filtering out non-executable statements without executing them, saving time while achieving filtering + - Consistency Filtering: Ensures consistency among the natural language question, SQL, and database schema 2. 
**Data Generation** - - SQL Variant Generation: Generates semantically equivalent variants based on existing SQL - - SQL Synthesis: Generates new SQL statements based on the database schema + - SQL Variation Generation: Generates semantically richer variations based on existing SQL + - SQL Synthesis: Generates new SQL statements from scratch based on the database schema - Question Generation: Generates corresponding natural language descriptions based on SQL and schema 3. **Training Data Construction** - Prompt Generation: Integrates natural language questions, database schema, and instruction prompts - - Chain-of-Thought Generation: Constructs step-by-step reasoning processes + - Chain-of-Thought Generation: Constructs a step-by-step reasoning process (Chain-of-Thought), generating multiple reasoning candidate processes + - Reasoning Candidate Voting: Selects the best reasoning process from multiple candidates -4. **Data Grading** - - Syntax Difficulty Grading: Assigns levels based on the complexity of the SQL statement - - Execution Difficulty Grading: Evaluates difficulty based on SQL execution pass rate +4. **Data Difficulty Grading** + - Syntax Difficulty Grading: Grades based on the complexity of the SQL statement + - Execution Difficulty Grading: Assesses difficulty based on SQL execution pass rates ## 2. Quick Start @@ -87,7 +82,7 @@ export DF_API_KEY="sk-xxxxx" $env:DF_API_KEY = "sk-xxxxx" ``` -Configure the API endpoints in `text2sql_pipeline_gen.py` and `text2sql_pipeline_refine.py`: +Configure the API endpoints in `text2sql_pipeline_gen.py` and `text2sql_pipeline_refine.py`. Here, `llm_serving` is the base model for constructing data, and `embedding_serving` is used when generating natural language queries to create multiple queries, vectorize them, and select the best one. ```python self.llm_serving = APILLMServing_request( @@ -96,30 +91,18 @@ self.llm_serving = APILLMServing_request( max_workers=100 ) -cot_generation_api_llm_serving = APILLMServing_request( - api_url="https://api.openai.com/v1/chat/completions", - model_name="gpt-4o", # Optionally use a more powerful model for generating chain-of-thought - max_workers=100 -) - -embedding_serving = APILLMServing_request( +self.embedding_serving = APILLMServing_request( api_url="https://api.openai.com/v1/embeddings", model_name="text-embedding-ada-002", max_workers=100 ) ``` -Service Purpose Description: - -- `llm_serving`: Handles general tasks -- `cot_generation_api_llm_serving`: Generates complex reasoning chains (Chain-of-Thought) -- `embedding_serving`: Generates text embedding vectors - ### Step 5: Configure Database #### Using Example Databases -The pipeline supports automatic download of example databases. When the `db_root_path` parameter is an empty string, the system will automatically download example database files from Hugging Face. +The pipeline supports automatic downloading of example databases. When the `db_root_path` parameter is an empty string, the system will automatically download example database files from Hugging Face. First, configure `HF_TOKEN` (can be obtained from the Hugging Face website): @@ -135,20 +118,20 @@ export HF_TOKEN="hf_xxxxx" $env:HF_TOKEN = "hf_xxxxx" ``` -After configuration, keep the `db_root_path` parameter as an empty string. +After configuration, simply keep the `db_root_path` parameter as an empty string. #### Using Custom Databases -To use custom databases, set the `db_root_path` parameter to the database folder path. Currently supports SQLite and MySQL databases. 
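Before pointing the pipeline at a custom folder, a quick layout check can save a failed run. The sketch below is not part of DataFlow; it only relies on the SQLite directory convention described in the next subsection (each `.sqlite` / `.db` file name is the `db_id`, and nested folders are allowed).

```python
from pathlib import Path

def list_sqlite_db_ids(db_root_path: str) -> dict[str, Path]:
    """Map each db_id to its database file under db_root_path (searched recursively)."""
    root = Path(db_root_path)
    if not root.is_dir():
        raise FileNotFoundError(f"db_root_path does not exist: {root}")
    db_files = [p for p in root.rglob("*") if p.suffix in {".sqlite", ".db"}]
    if not db_files:
        raise ValueError(f"No .sqlite/.db files found under {root}")
    # File name without extension serves as the db_id referenced by the pipeline.
    return {p.stem: p for p in db_files}

# Example with a placeholder path: prints the db_ids the database manager should resolve.
print(sorted(list_sqlite_db_ids("/path/to/your/database")))
```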
+To use custom databases, set the `db_root_path` parameter to the path of your database folder. Currently supports SQLite and MySQL databases. ##### SQLite Database Configuration -SQLite is a file-based database system. When using it, you need to specify the path where the database files are stored. +SQLite is a file-based database system. When using it, you need to specify the storage path for the database files. -- **Database Root Directory**: The directory containing all database files +- **Database Root Directory**: Directory containing all database files - This directory should contain multiple database files in `.sqlite` or `.db` format - - The filename of each database file is the `db_id`, in the format `db_id.sqlite` or `db_id.db` - - The database manager supports directory structures with arbitrary nesting levels + - The filename of each database file is the `db_id`, formatted as `db_id.sqlite` or `db_id.db` + - The database manager supports nested directory structures of any depth **Directory Structure Example:** ``` @@ -159,15 +142,15 @@ databases/ **Configuration Example:** ```python -# Automatically download example database +# Automatically download example databases db_root_path = "" model = Text2SQLGeneration_APIPipeline(db_root_path=db_root_path) -# Or manually specify a local database path +# Or manually specify local database path db_root_path = "/path/to/your/database" model = Text2SQLGeneration_APIPipeline(db_root_path=db_root_path) -# Database Manager Configuration +# Database manager configuration database_manager = DatabaseManager( db_type="sqlite", config={ @@ -194,25 +177,23 @@ database_manager = DatabaseManager( ) ``` -> **Note**: Ensure the MySQL service is running and you have access permissions to the respective databases. +> **Note**: Ensure the MySQL service is running and you have access permissions to the relevant database. -### Step 6: Configure SQL Source Files +### Step 6: Configure SQL Source File Choose different pipelines based on your needs: #### 6.1 Data Refinement Pipeline -Input data must contain the following fields: +Input data must contain at least the following fields; other fields can be retained and will not be affected: - **db_id**: Database file name (Database ID) -- **question**: Natural language question - **SQL**: Standard SQL answer **Data Format Example (JSON):** ```json { "db_id": "california_schools", - "question": "What is the highest eligible free rate for K-12 students in the schools in Alameda County?", "SQL": "SELECT `Free Meal Count (K-12)` / `Enrollment (K-12)` FROM frpm WHERE `County Name` = 'Alameda' ORDER BY (CAST(`Free Meal Count (K-12)` AS REAL) / `Enrollment (K-12)`) DESC LIMIT 1" } ``` @@ -220,8 +201,8 @@ Input data must contain the following fields: **Storage Configuration:** ```python self.storage = FileStorage( - first_entry_file_name="../example_data/Text2SQLPipeline/pipeline_refine.jsonl", - cache_path="./cache_local", + first_entry_file_name="../example_data/Text2SQLPipeline/pipeline_refine.jsonl", # This can also be replaced with your SQL dataset file path + cache_path="./cache", file_name_prefix="dataflow_cache_step", cache_type="jsonl" ) @@ -229,11 +210,11 @@ self.storage = FileStorage( #### 6.2 Data Synthesis Pipeline -This mode does not require existing data and synthesizes data directly from the database. After configuring the database, set `first_entry_file_name` to an empty string: +This mode does not require existing data; it synthesizes data directly from databases. 
After configuring the database, point `first_entry_file_name` at an empty JSONL file:

```python
self.storage = FileStorage(
-            first_entry_file_name="",
+            first_entry_file_name="../example_data/Text2SQLPipeline/empty.jsonl", # The data synthesis pipeline does not need an existing dataset; an empty jsonl file is used only because DataFlow requires a file input
             cache_path="./cache",
             file_name_prefix="dataflow_cache_step",
             cache_type="jsonl"
         )
```

@@ -252,23 +233,23 @@ or
python api_pipelines/text2sql_pipeline_refine.py
```

-You can choose to run any Pipeline based on your needs; the running method is similar. Subsequent sections will introduce the operators used in the Pipeline and how to configure their parameters.
+You can choose to run either pipeline based on your needs; both are run the same way. The following sections introduce the operators used in the pipelines and how to configure their parameters.

-## 3. Dataflow and Pipeline Logic
+## 3. Data Flow and Pipeline Logic

### 3.1 Data Filters

#### 3.1.1 **SQL Execution Filter (SQLExecutionFilter)**

-The **SQL Execution Filter** (`SQLExecutionFilter`) verifies the correctness of SQL statements by actually executing them, filtering out SQL statements that cannot be executed normally.
+The **SQL Execution Filter** (`SQLExecutionFilter`) validates SQL statement correctness by actually executing the statements, filtering out those that cannot be executed.

**Functionality:**

-* Verifies the executability of SQL statements
+* Validates SQL statement executability
* Filters out SQL statements with syntax errors or execution failures

-**Input**: SQL statement and database ID
-**Output**: Executable SQL statements
+**Input**: SQL statement and database ID
+**Output**: Executable SQL statements; non-executable ones are removed

```python
sql_execution_filter = SQLExecutionFilter(
    database_manager=database_manager
)
```

-#### 3.1.2 **SQL Consistency Filter (SQLConsistencyFilter)**
+#### 3.1.2 **SQL Consistency Filter (Text2SQLCorrespondenceFilter)**

-The **SQL Consistency Filter** (`SQLConsistencyFilter`) checks the consistency between the SQL statement, the question, and the database schema, ensuring that the generated SQL correctly answers the corresponding question.
+The **SQL Consistency Filter** (`Text2SQLCorrespondenceFilter`) checks consistency between the SQL statement, the question, and the database schema, ensuring the generated SQL correctly answers the corresponding question.
**Functionality:** -* Verifies consistency between the SQL statement, the question, and the database schema -* Filters out SQL statements that do not match the question or database schema +* Validates consistency between the SQL statement, the natural language question, and the database schema +* Filters out SQL statements that do not match the natural language question or database schema -**Input**: SQL statement, database ID, and question -**Output**: SQL statements consistent with the question +**Input**: SQL statement, database ID, natural language question, and evidence +**Output**: SQL statements consistent with the natural language question and database schema; inconsistent ones are filtered and deleted ```python -sql_consistency_filter = SQLConsistencyFilter( +text2sql_correspondence_filter = Text2SQLCorrespondenceFilter( llm_serving=llm_serving, database_manager=database_manager, - prompt_template=SQLConsistencyFilterPrompt() + prompt_template=Text2SQLCorrespondenceFilterPrompt() +) +``` + +#### 3.1.3 **SQL Executability Filter (SQLExecutabilityFilter)** + +The **SQL Executability Filter** (`SQLExecutabilityFilter`) uses the database to generate query plans (EXPLAIN) for SQL, filtering out non-executable statements without executing them, saving time while achieving filtering. In the database, being able to generate a query plan indicates executability, so this method can filter out non-executable SQL statements. + +**Functionality:** +* Uses the database to generate query plans for SQL, filtering out non-executable SQL statements +* Does not execute SQL, saving time while achieving filtering + +**Input**: SQL statement and database ID +**Output**: Executable SQL statements; non-executable ones are deleted + +```python +sql_executability_filter = SQLExecutabilityFilter( + database_manager=database_manager ) ``` @@ -300,15 +298,15 @@ sql_consistency_filter = SQLConsistencyFilter( #### 3.2.1 **SQL Generator (SQLGenerator)** -The **SQL Generator** (`SQLGenerator`) is responsible for generating SQL query statements based on the database schema, providing raw SQL data for subsequent data processing flows. +The **SQL Generator** (`SQLGenerator`) is responsible for generating SQL query statements based on the database schema, providing raw SQL data for subsequent data processing workflows. **Functionality:** * Automatically generates SQL query statements based on the database schema * Supports batch generation of a specified number of SQL statements -**Input**: Database schema information -**Output**: Generated SQL statements and corresponding database IDs +**Input**: Database schema information +**Output**: Generated SQL statements, corresponding database ID, and SQL complexity label (`sql_complexity_type`) ```python sql_generator = SQLGenerator( @@ -319,17 +317,17 @@ sql_generator = SQLGenerator( ) ``` -#### 3.2.2 **SQL Variant Generator (SQLVariationGenerator)** +#### 3.2.2 **SQL Variation Generator (SQLVariationGenerator)** -The **SQL Variant Generator** (`SQLVariationGenerator`) generates multiple functionally equivalent variants based on existing SQL statements, enriching the diversity of the dataset. +The **SQL Variation Generator** (`SQLVariationGenerator`) generates multiple functionally equivalent variants based on existing SQL statements, enriching dataset diversity. 
**Functionality:** -* Generates functionally equivalent SQL variants +* Generates functionally equivalent SQL variations * Increases the diversity and complexity of SQL statements **Input**: Original SQL statement and database ID -**Output**: Collection of SQL variants +**Output**: Set of SQL variations ```python sql_variation_generator = SQLVariationGenerator( @@ -342,15 +340,15 @@ sql_variation_generator = SQLVariationGenerator( #### 3.2.3 **Question Generator (Text2SQLQuestionGenerator)** -The **Question Generator** (`Text2SQLQuestionGenerator`) generates corresponding natural language questions based on given SQL statements, constructing Text-to-SQL question-answer pairs. +The **Question Generator** (`Text2SQLQuestionGenerator`) generates corresponding natural language questions based on given SQL statements, constructing Text-to-SQL QA pairs. **Functionality:** * Generates natural language questions based on SQL statements * Supports generating multiple candidate questions -**Input**: SQL statement and database ID -**Output**: Natural language question +**Input**: SQL statement and database ID +**Output**: Natural language question and evidence (`question` / `evidence`), along with a question type field `question_type` ```python text2sql_question_generator = Text2SQLQuestionGenerator( @@ -364,14 +362,14 @@ text2sql_question_generator = Text2SQLQuestionGenerator( #### 3.2.4 **Prompt Generator (Text2SQLPromptGenerator)** -The **Prompt Generator** (`Text2SQLPromptGenerator`) generates prompt templates for model training based on the question and database schema. +The **Prompt Generator** (`Text2SQLPromptGenerator`) generates prompt templates for model training based on questions and database schema. **Functionality:** * Generates structured prompt templates * Integrates question and database schema information -**Input**: Question and database ID +**Input**: Question, evidence, and database ID **Output**: Formatted prompt template ```python @@ -383,36 +381,53 @@ text2sql_prompt_generator = Text2SQLPromptGenerator( #### 3.2.5 **Chain-of-Thought Generator (Text2SQLCoTGenerator)** -The **Chain-of-Thought Generator** (`Text2SQLCoTGenerator`) generates detailed reasoning processes for SQL queries, helping the model understand the conversion logic from question to SQL. +The **Chain-of-Thought Generator** (`Text2SQLCoTGenerator`) generates multiple detailed reasoning processes for SQL queries, helping the model understand the translation logic from questions to SQL. **Functionality:** * Generates reasoning processes for SQL queries -* Supports retry mechanism to ensure generation quality +* To ensure quality, generates multiple reasoning process candidates (without validation) -**Input**: SQL statement, question, and database ID -**Output**: Chain-of-thought reasoning process +**Input**: SQL statement, question, evidence, and database ID +**Output**: Multiple reasoning process candidates (`cot_responses`), to be used for subsequent voting to select the best reasoning process ```python sql_cot_generator = Text2SQLCoTGenerator( - llm_serving=cot_generation_api_llm_serving, + llm_serving=llm_serving, database_manager=database_manager, - max_retries=3, - enable_retry=True, + sampling_num=3, prompt_template=Text2SQLCotGeneratorPrompt() ) ``` +#### 3.2.6 **Reasoning Process Voter (Text2SQLCoTVotingGenerator)** + +The **Reasoning Process Voter** (`Text2SQLCoTVotingGenerator`) performs execution consistency voting on multiple reasoning process candidates to select the best reasoning process. 
+ +**Functionality:** + +* Performs execution consistency voting on multiple reasoning process candidates +* Selects and outputs the best reasoning process + +**Input**: `cot_responses` and database ID +**Output**: Final reasoning process `cot_reasoning` + +```python +sql_cot_voter = Text2SQLCoTVotingGenerator( + database_manager=database_manager +) +``` + ### 3.3 Data Evaluators #### 3.3.1 **Component Difficulty Evaluator (SQLComponentClassifier)** -The **Component Difficulty Evaluator** (`SQLComponentClassifier`) analyzes the component complexity of SQL statements and labels the difficulty level for data samples. +The **Component Difficulty Evaluator** (`SQLComponentClassifier`) analyzes the component complexity of SQL statements and annotates difficulty levels for data samples. **Functionality:** * Analyzes the component complexity of SQL statements -* Labels difficulty levels for samples +* Annotates difficulty levels for samples **Input**: SQL statement **Output**: SQL component difficulty level @@ -426,7 +441,7 @@ sql_component_classifier = SQLComponentClassifier( #### 3.3.2 **Execution Difficulty Evaluator (SQLExecutionClassifier)** -The **Execution Difficulty Evaluator** (`SQLExecutionClassifier`) evaluates the execution difficulty of SQL queries, making comprehensive judgments based on multiple generation results. +The **Execution Difficulty Evaluator** (`SQLExecutionClassifier`) evaluates the execution difficulty of SQL queries based on comprehensive judgment from multiple generation results. **Functionality:** @@ -448,32 +463,37 @@ sql_execution_classifier = SQLExecutionClassifier( ### 3.4 Prompt Template System -Each component in the pipeline uses specialized prompt template classes to ensure generation quality and consistency: +Each component in the pipeline uses a dedicated prompt template class to ensure generation quality and consistency: - `SelectSQLGeneratorPrompt()` - SQL generation prompts -- `SQLVariationGeneratorPrompt()` - SQL variant generation prompts +- `SQLVariationGeneratorPrompt()` - SQL variation generation prompts - `Text2SQLQuestionGeneratorPrompt()` - Question generation prompts - `Text2SQLPromptGeneratorPrompt()` - Training prompt generation - `Text2SQLCotGeneratorPrompt()` - CoT reasoning generation prompts -- `SQLConsistencyFilterPrompt()` - Consistency filtering prompts +- `Text2SQLCorrespondenceFilterPrompt()` - Consistency filtering prompts ## 4. 
**Output Data** -- **Format**: `jsonl` (Each step generates a file) -- **Field Description**: +- **Format**: `jsonl` (A file is generated for each step) +- **Field Descriptions**: - `db_id`: Database ID - `question`: Natural language question + - `question_type`: Natural language question type + - `evidence`: Evidence/external knowledge accompanying question generation - `SQL`: Standard SQL answer - - `prompt`: Prompt for training, includes natural language question, database schema, and prompt information - - `cot_reasoning`: Chain-of-thought reasoning data, includes reasoning process and final answer, used for model training - - `sql_component_difficulty`: SQL component complexity assessment - - `sql_execution_difficulty`: SQL execution complexity assessment + - `sql_variation_type`: SQL variation type (only exists in data generated by the SQL refinement pipeline) + - `sql_complexity_type`: SQL complexity type (only exists in data generated by the SQL synthesis pipeline) + - `prompt`: Prompt used for training, containing natural language question, database schema, and prompt information + - `cot_reasoning`: Chain-of-thought data, containing reasoning process and final answer, used for model training + - `sql_component_difficulty`: SQL component difficulty evaluation + - `sql_execution_difficulty`: SQL execution difficulty evaluation - **Example**: ```json { "db_id":"california_schools", "SQL":"SELECT AVG(s.AvgScrRead) AS average_reading_score\nFROM satscores s\nINNER JOIN frpm f ON s.cds = f.CDSCode\nINNER JOIN schools sc ON f.CDSCode = sc.CDSCode\nWHERE s.cname = 'Alameda'\n AND f.\"Charter School (Y\/N)\" = 1\n AND f.\"Charter Funding Type\" = 'Directly funded'\n AND sc.County = 'Alameda';", "question":"What is the average reading score for directly funded charter schools in Alameda County?", + "evidence":"This question focuses on directly funded charter schools in Alameda County.", "prompt":"Task Overview: /* Given the following database schema: ... /* Answer the following: What is the average reading score for directly funded charter schools in Alameda County? * Let's think step by step", "cot_reasoning":"To translate the natural language question into an executable SQLite query, we will follow these steps. ... we can construct the full SQLite query based on these steps:\n\n```sql\nSELECT AVG(s.AvgScrRead) AS average_reading_score\nFROM satscores s\nINNER JOIN frpm f ON s.cds = f.CDSCode\nINNER JOIN schools sc ON f.CDSCode = sc.CDSCode\nWHERE s.cname = 'Alameda'\n AND f.\"Charter School (Y\/N)\" = 1\n AND f.\"Charter Funding Type\" = 'Directly funded'\n AND sc.County = 'Alameda';\n```\n\nThis query follows the logic outlined above and ensures alignment with the reference solution.", "sql_component_difficulty":"medium", @@ -481,25 +501,25 @@ Each component in the pipeline uses specialized prompt template classes to ensur } ``` -## 5. Execution Method +## 5. Running Methods -Two pipelines have been designed here, allowing different configurations to be executed via simple Python commands to meet various data requirements: +Two pipelines are designed here, executed via simple Python commands with different configurations to meet various data needs: * **Data Synthesis Pipeline**: ```bash - python /path/to/text2sql_generation_pipeline.py + python api_pipelines/text2sql_pipeline_gen.py ``` -* **Data Optimization Pipeline**: +* **Data Refinement Pipeline**: ```bash - python /path/to/text2sql_refine_pipeline.py + python api_pipelines/text2sql_pipeline_refine.py ``` -## 6. 
Pipeline Example +## 6. Pipeline Examples -Below is an example demonstrating how to chain multiple operators for reasoning data processing. This example shows initializing and sequentially executing filtering and cleaning steps. +The following provides example pipelines demonstrating how to use multiple operators for reasoning data processing. These examples show how to initialize a reasoning data processing pipeline and sequentially execute various filtering and cleaning steps. * **Data Synthesis Pipeline**: @@ -509,106 +529,128 @@ class Text2SQLGeneration_APIPipeline(): self.logger = get_logger() self.db_root_path = db_root_path - # Automatic database download if not db_root_path: try: self.db_root_path = download_and_extract_database(self.logger) self.logger.info(f"Using automatically downloaded database at: {self.db_root_path}") except Exception as e: self.logger.error(f"Failed to auto-download database: {e}") - raise + raise else: self.logger.info(f"Using manually specified database path: {self.db_root_path}") + if not os.path.exists(self.db_root_path): + raise FileNotFoundError(f"Database path does not exist: {self.db_root_path}") + self.storage = FileStorage( - first_entry_file_name="", + first_entry_file_name="../example_data/Text2SQLPipeline/empty.jsonl", cache_path="./cache", file_name_prefix="dataflow_cache_step", cache_type="jsonl", ) self.llm_serving = APILLMServing_request( - api_url="http://api.openai.com/v1/chat/completions", + api_url="https://api.openai.com/v1/chat/completions", model_name="gpt-4o", max_workers=100 ) - cot_generation_api_llm_serving = APILLMServing_request( - api_url="http://api.openai.com/v1/chat/completions", - model_name="gpt-4o", - max_workers=100 - ) - - embedding_serving = APILLMServing_request( - api_url="http://api.openai.com/v1/embeddings", + self.embedding_serving = APILLMServing_request( + api_url="https://api.openai.com/v1/embeddings", model_name="text-embedding-ada-002", max_workers=100 ) + # SQLite and MySQL are currently supported + # db_type can be sqlite or mysql, which must match your database type + # If sqlite is selected, root_path must be provided, this path must exist and contain database files + # If mysql is selected, host, user, password must be provided, these credentials must be correct and have access permissions + # MySQL example: + # database_manager = DatabaseManager( + # db_type="mysql", + # config={ + # "host": "localhost", + # "user": "root", + # "password": "your_password", + # "database": "your_database_name" + # } + # ) + # SQLite example: database_manager = DatabaseManager( db_type="sqlite", config={ "root_path": self.db_root_path } ) - + self.sql_generator_step1 = SQLGenerator( llm_serving=self.llm_serving, database_manager=database_manager, - generate_num=50, + generate_num=2, prompt_template=SelectSQLGeneratorPrompt() ) - self.sql_execution_filter_step2 = SQLExecutionFilter( - database_manager=database_manager, + self.sql_executability_filter_step2 = SQLExecutabilityFilter( + database_manager=database_manager ) self.text2sql_question_generator_step3 = Text2SQLQuestionGenerator( llm_serving=self.llm_serving, - embedding_serving=embedding_serving, + embedding_serving=self.embedding_serving, database_manager=database_manager, - question_candidates_num=5, + question_candidates_num=3, prompt_template=Text2SQLQuestionGeneratorPrompt() ) - self.text2sql_prompt_generator_step4 = Text2SQLPromptGenerator( + self.text2sql_correspondence_filter_step4 = Text2SQLCorrespondenceFilter( + llm_serving=self.llm_serving, + 
database_manager=database_manager, + prompt_template=Text2SQLCorrespondenceFilterPrompt() + ) + + self.text2sql_prompt_generator_step5 = Text2SQLPromptGenerator( database_manager=database_manager, prompt_template=Text2SQLPromptGeneratorPrompt() ) - self.sql_cot_generator_step5 = Text2SQLCoTGenerator( - llm_serving=cot_generation_api_llm_serving, + self.sql_cot_generator_step6 = Text2SQLCoTGenerator( + llm_serving=self.llm_serving, database_manager=database_manager, - max_retries=3, - enable_retry=True, prompt_template=Text2SQLCotGeneratorPrompt() ) - self.sql_component_classifier_step6 = SQLComponentClassifier( + self.sql_cot_voting_generator_step7 = Text2SQLCoTVotingGenerator( + database_manager=database_manager + ) + + self.sql_component_classifier_step8 = SQLComponentClassifier( difficulty_thresholds=[2, 4, 6], difficulty_labels=['easy', 'medium', 'hard', 'extra'] ) - self.sql_execution_classifier_step7 = SQLExecutionClassifier( + self.sql_execution_classifier_step9 = SQLExecutionClassifier( llm_serving=self.llm_serving, database_manager=database_manager, num_generations=10, difficulty_thresholds=[2, 5, 9], difficulty_labels=['extra', 'hard', 'medium', 'easy'] ) - + def forward(self): + sql_key = "SQL" db_id_key = "db_id" question_key = "question" + evidence_key = "evidence" self.sql_generator_step1.run( storage=self.storage.step(), output_sql_key=sql_key, - output_db_id_key=db_id_key + output_db_id_key=db_id_key, + output_sql_complexity_key="sql_complexity_type" ) - self.sql_execution_filter_step2.run( + self.sql_executability_filter_step2.run( storage=self.storage.step(), input_sql_key=sql_key, input_db_id_key=db_id_key @@ -618,31 +660,49 @@ class Text2SQLGeneration_APIPipeline(): storage=self.storage.step(), input_sql_key=sql_key, input_db_id_key=db_id_key, - output_question_key=question_key + output_question_key=question_key, + output_evidence_key=evidence_key ) - self.text2sql_prompt_generator_step4.run( + self.text2sql_correspondence_filter_step4.run( + storage=self.storage.step(), + input_sql_key=sql_key, + input_db_id_key=db_id_key, + input_question_key=question_key, + input_evidence_key=evidence_key + ) + + self.text2sql_prompt_generator_step5.run( storage=self.storage.step(), input_question_key=question_key, input_db_id_key=db_id_key, + input_evidence_key=evidence_key, output_prompt_key="prompt" ) - self.sql_cot_generator_step5.run( + self.sql_cot_generator_step6.run( storage=self.storage.step(), input_sql_key=sql_key, input_question_key=question_key, input_db_id_key=db_id_key, + input_evidence_key=evidence_key, output_cot_key="cot_reasoning" ) - self.sql_component_classifier_step6.run( + self.sql_cot_voting_generator_step7.run( + storage=self.storage.step(), + input_cot_responses_key="cot_responses", + input_db_id_key=db_id_key, + output_cot_key="cot_reasoning" + ) + + self.sql_component_classifier_step8.run( storage=self.storage.step(), input_sql_key=sql_key, output_difficulty_key="sql_component_difficulty" ) - self.sql_execution_classifier_step7.run( + self.sql_execution_classifier_step9.run( storage=self.storage.step(), input_sql_key=sql_key, input_db_id_key=db_id_key, @@ -651,111 +711,125 @@ class Text2SQLGeneration_APIPipeline(): ) if __name__ == "__main__": - # Set db_root_path to your local DB path, or "" to auto-download + # If you have your own database files, you can set the db_root_path to the path of your database files + # If not, please set the db_root_path "", and we will download the example database files automatically db_root_path = "" - + model = 
Text2SQLGeneration_APIPipeline(db_root_path=db_root_path) model.forward() ``` -* **Data Optimization Pipeline**: +* **Data Refinement Pipeline**: ```python class Text2SQLRefine_APIPipeline(): def __init__(self, db_root_path=""): self.logger = get_logger() self.db_root_path = db_root_path - # Automatic database download if not db_root_path: try: self.db_root_path = download_and_extract_database(self.logger) self.logger.info(f"Using automatically downloaded database at: {self.db_root_path}") except Exception as e: self.logger.error(f"Failed to auto-download database: {e}") - raise + raise else: self.logger.info(f"Using manually specified database path: {self.db_root_path}") + if not os.path.exists(self.db_root_path): + raise FileNotFoundError(f"Database path does not exist: {self.db_root_path}") + self.storage = FileStorage( first_entry_file_name="../example_data/Text2SQLPipeline/pipeline_refine.jsonl", - cache_path="./cache_local", + cache_path="./cache", file_name_prefix="dataflow_cache_step", cache_type="jsonl" ) self.llm_serving = APILLMServing_request( - api_url="http://api.openai.com/v1/chat/completions", + api_url="https://api.openai.com/v1/chat/completions", model_name="gpt-4o", max_workers=100 ) - cot_generation_api_llm_serving = APILLMServing_request( - api_url="http://api.openai.com/v1/chat/completions", - model_name="gpt-4o", - max_workers=100 - ) - - embedding_serving = APILLMServing_request( - api_url="http://api.openai.com/v1/embeddings", + self.embedding_serving = APILLMServing_request( + api_url="https://api.openai.com/v1/embeddings", model_name="text-embedding-ada-002", max_workers=100 ) + # SQLite and MySQL are currently supported + # db_type can be sqlite or mysql, which must match your database type + # If sqlite is selected, root_path must be provided, this path must exist and contain database files + # If mysql is selected, host, user, password must be provided, these credentials must be correct and have access permissions + # MySQL example: + # database_manager = DatabaseManager( + # db_type="mysql", + # config={ + # "host": "localhost", + # "user": "root", + # "password": "your_password", + # "database": "your_database_name" + # } + # ) + # SQLite example: database_manager = DatabaseManager( db_type="sqlite", config={ "root_path": self.db_root_path } ) - - self.sql_execution_filter_step1 = SQLExecutionFilter( - database_manager=database_manager - ) - self.sql_consistency_filter_step2 = SQLConsistencyFilter( - llm_serving=self.llm_serving, - database_manager=database_manager, - prompt_template=SQLConsistencyFilterPrompt() + self.sql_executability_filter_step1 = SQLExecutabilityFilter( + database_manager=database_manager ) - self.sql_variation_generator_step3 = SQLVariationGenerator( + self.sql_variation_generator_step2 = SQLVariationGenerator( llm_serving=self.llm_serving, database_manager=database_manager, - num_variations=5, + num_variations=3, # Number of variations to generate for each SQL prompt_template=SQLVariationGeneratorPrompt() ) - self.sql_execution_filter_step4 = SQLExecutionFilter( + self.sql_executability_filter_step3 = SQLExecutabilityFilter( database_manager=database_manager ) - self.text2sql_question_generator_step5 = Text2SQLQuestionGenerator( + self.text2sql_question_generator_step4 = Text2SQLQuestionGenerator( llm_serving=self.llm_serving, - embedding_serving=embedding_serving, + embedding_serving=self.embedding_serving, database_manager=database_manager, - question_candidates_num=5, + question_candidates_num=3, 
prompt_template=Text2SQLQuestionGeneratorPrompt() ) + self.text2sql_correspondence_filter_step5 = Text2SQLCorrespondenceFilter( + llm_serving=self.llm_serving, + database_manager=database_manager, + prompt_template=Text2SQLCorrespondenceFilterPrompt() + ) + self.text2sql_prompt_generator_step6 = Text2SQLPromptGenerator( database_manager=database_manager, prompt_template=Text2SQLPromptGeneratorPrompt() ) self.sql_cot_generator_step7 = Text2SQLCoTGenerator( - llm_serving=cot_generation_api_llm_serving, + llm_serving=self.llm_serving, database_manager=database_manager, - max_retries=3, - enable_retry=True, prompt_template=Text2SQLCotGeneratorPrompt() ) - self.sql_component_classifier_step8 = SQLComponentClassifier( + self.sql_cot_voting_generator_step8 = Text2SQLCoTVotingGenerator( + database_manager=database_manager + ) + + self.sql_component_classifier_step9 = SQLComponentClassifier( difficulty_thresholds=[2, 4, 6], difficulty_labels=['easy', 'medium', 'hard', 'extra'] ) - self.sql_execution_classifier_step9 = SQLExecutionClassifier( + self.sql_execution_classifier_step10 = SQLExecutionClassifier( llm_serving=self.llm_serving, database_manager=database_manager, num_generations=10, @@ -763,47 +837,54 @@ class Text2SQLRefine_APIPipeline(): difficulty_labels=['extra', 'hard', 'medium', 'easy'] ) + def forward(self): + sql_key = "SQL" db_id_key = "db_id" question_key = "question" + evidence_key = "evidence" - self.sql_execution_filter_step1.run( + self.sql_executability_filter_step1.run( storage=self.storage.step(), input_sql_key=sql_key, input_db_id_key=db_id_key ) - self.sql_consistency_filter_step2.run( + self.sql_variation_generator_step2.run( storage=self.storage.step(), input_sql_key=sql_key, input_db_id_key=db_id_key, - input_question_key=question_key + output_sql_variation_type_key="sql_variation_type" ) - self.sql_variation_generator_step3.run( + self.sql_executability_filter_step3.run( storage=self.storage.step(), input_sql_key=sql_key, input_db_id_key=db_id_key ) - self.sql_execution_filter_step4.run( + self.text2sql_question_generator_step4.run( storage=self.storage.step(), input_sql_key=sql_key, - input_db_id_key=db_id_key + input_db_id_key=db_id_key, + output_question_key=question_key, + output_evidence_key=evidence_key ) - self.text2sql_question_generator_step5.run( + self.text2sql_correspondence_filter_step5.run( storage=self.storage.step(), input_sql_key=sql_key, input_db_id_key=db_id_key, - output_question_key=question_key + input_question_key=question_key, + input_evidence_key=evidence_key ) self.text2sql_prompt_generator_step6.run( storage=self.storage.step(), input_question_key=question_key, input_db_id_key=db_id_key, + input_evidence_key=evidence_key, output_prompt_key="prompt" ) @@ -812,16 +893,24 @@ class Text2SQLRefine_APIPipeline(): input_sql_key=sql_key, input_question_key=question_key, input_db_id_key=db_id_key, + input_evidence_key=evidence_key, output_cot_key="cot_reasoning" ) - self.sql_component_classifier_step8.run( + self.sql_cot_voting_generator_step8.run( + storage=self.storage.step(), + input_cot_responses_key="cot_responses", + input_db_id_key=db_id_key, + output_cot_key="cot_reasoning" + ) + + self.sql_component_classifier_step9.run( storage=self.storage.step(), input_sql_key=sql_key, output_difficulty_key="sql_component_difficulty" ) - self.sql_execution_classifier_step9.run( + self.sql_execution_classifier_step10.run( storage=self.storage.step(), input_sql_key=sql_key, input_db_id_key=db_id_key, @@ -830,7 +919,8 @@ class Text2SQLRefine_APIPipeline(): ) 
if __name__ == "__main__": - # Set db_root_path to your local DB path, or "" to auto-download + # If you have your own database files, you can set the db_root_path to the path of your database files + # If not, please set the db_root_path "", and we will download the example database files automatically db_root_path = "" model = Text2SQLRefine_APIPipeline(db_root_path=db_root_path) diff --git a/docs/zh/notes/guide/domain_specific_operators/text2sql_operators.md b/docs/zh/notes/guide/domain_specific_operators/text2sql_operators.md index 76a42611e..ff74dc790 100644 --- a/docs/zh/notes/guide/domain_specific_operators/text2sql_operators.md +++ b/docs/zh/notes/guide/domain_specific_operators/text2sql_operators.md @@ -60,6 +60,12 @@ Text-to-SQL算子是专门用于Text-to-SQL问题数据处理和质量提升的 生成SQL推理的逐步思维链过程 OmniSQL + + Text2SQLCoTVotingGenerator✨ + 推理链筛选 + 对候选推理过程进行执行一致性投票,选出最终CoT + - + @@ -109,7 +115,13 @@ Text-to-SQL算子是专门用于Text-to-SQL问题数据处理和质量提升的 - - SQLConsistencyFilter✨ + SQLExecutabilityFilter✨ + 数据清洗 + 使用查询计划过滤不可执行SQL语句 + - + + + Text2SQLCorrespondenceFilter✨ 数据清洗 验证SQL与问题描述的语义一致性 - @@ -205,7 +217,7 @@ from dataflow.prompts.text2sql import ( SelectSQLGeneratorPrompt, Text2SQLQuestionGeneratorPrompt, Text2SQLPromptGeneratorPrompt, - SQLConsistencyFilterPrompt, + Text2SQLCorrespondenceFilterPrompt, SQLVariationGeneratorPrompt ) ``` @@ -235,6 +247,7 @@ from dataflow.prompts.text2sql import ( - `run()` - `output_sql_key`: 输出SQL语句字段名,默认"SQL" - `output_db_id_key`: 输出数据库ID字段名,默认"db_id" + - `output_sql_complexity_key`: 输出SQL复杂度字段名,默认"sql_complexity_type" **主要特性:** @@ -249,7 +262,6 @@ from dataflow.prompts.text2sql import ( from dataflow.prompts.text2sql import SelectSQLGeneratorPrompt sql_generator = SQLGenerator( - llm_serving=llm_serving, database_manager=database_manager, generate_num=50, prompt_template=SelectSQLGeneratorPrompt() @@ -257,7 +269,8 @@ sql_generator = SQLGenerator( sql_generator.run( storage=storage.step(), output_sql_key="SQL", - output_db_id_key="db_id" + output_db_id_key="db_id", + output_sql_complexity_key="sql_complexity_type" ) ``` @@ -279,6 +292,7 @@ sql_generator.run( - `run()` - `input_sql_key`: SQL语句字段名,默认"SQL" - `input_db_id_key`: 数据库ID字段名,默认"db_id" + - `output_sql_variation_type_key`: 输出SQL变体类型字段名,默认"sql_variation_type" **主要特性:** @@ -292,7 +306,6 @@ sql_generator.run( from dataflow.prompts.text2sql import SQLVariationGeneratorPrompt sql_variation_generator = SQLVariationGenerator( - llm_serving=llm_serving, database_manager=database_manager, num_variations=5, prompt_template=SQLVariationGeneratorPrompt() @@ -300,7 +313,8 @@ sql_variation_generator = SQLVariationGenerator( sql_variation_generator.run( storage=storage.step(), input_sql_key="SQL", - input_db_id_key="db_id" + input_db_id_key="db_id", + output_sql_variation_type_key="sql_variation_type" ) ``` @@ -325,6 +339,7 @@ sql_variation_generator.run( - `input_sql_key`: SQL语句字段名,默认"SQL" - `input_db_id_key`: 数据库ID字段名,默认"db_id" - `output_question_key`: 输出问题字段名,默认"question" + - `output_evidence_key`: 输出证据字段名,默认"evidence" **主要特性:** @@ -332,6 +347,7 @@ sql_variation_generator.run( - 多候选问题生成和最优选择 - 结合数据库Schema的上下文理解 - 确保问题的自然性和准确性 +- 自动补充 `question_type` 问题类型字段 **使用示例:** @@ -339,7 +355,6 @@ sql_variation_generator.run( from dataflow.prompts.text2sql import Text2SQLQuestionGeneratorPrompt text2sql_question_generator = Text2SQLQuestionGenerator( - llm_serving=llm_serving, embedding_serving=embedding_serving, database_manager=database_manager, question_candidates_num=5, @@ -349,7 +364,8 @@ text2sql_question_generator.run( 
storage=storage.step(), input_sql_key="SQL", input_db_id_key="db_id", - output_question_key="question" + output_question_key="question", + output_evidence_key="evidence" ) ``` @@ -370,6 +386,7 @@ text2sql_question_generator.run( - `run()` - `input_question_key`: 问题字段名,默认"question" - `input_db_id_key`: 数据库ID字段名,默认"db_id" + - `input_evidence_key`: 证据字段名,默认"evidence" - `output_prompt_key`: 输出提示词字段名,默认"prompt" **主要特性:** @@ -392,6 +409,7 @@ text2sql_prompt_generator.run( storage=storage.step(), input_question_key="question", input_db_id_key="db_id", + input_evidence_key="evidence", output_prompt_key="prompt" ) ``` @@ -401,7 +419,7 @@ text2sql_prompt_generator.run( **功能描述:** 生成SQL推理的逐步思维链过程 - 基于问题和SQL生成详细的推理步骤 - 解释SQL构建的逻辑过程 -- 支持错误重试和质量保证 +- 生成多个候选推理过程(不做验证) - 提升模型的推理能力和可解释性 **输入参数:** @@ -409,20 +427,20 @@ text2sql_prompt_generator.run( - `__init__()` - `llm_serving`: LLM服务接口,用于CoT生成 - `database_manager`: 数据库管理器,用于Schema信息获取 - - `max_retries`: 最大重试次数,默认3 - - `enable_retry`: 是否启用重试机制,默认True + - `sampling_num`: 生成候选推理过程数量,默认3 - `prompt_template`: CoT生成的提示词模板 - `run()` - `input_sql_key`: SQL语句字段名,默认"SQL" - `input_question_key`: 问题字段名,默认"question" - `input_db_id_key`: 数据库ID字段名,默认"db_id" - - `output_cot_key`: 输出CoT推理字段名,默认"cot_reasoning" + - `input_evidence_key`: 证据字段名,默认"evidence" + - `output_cot_key`: 输出CoT推理字段名,默认"cot_reasoning"(实际输出列为 `cot_responses`) **主要特性:** - 高质量的推理链生成 -- 自动错误检测和重试机制 +- 多候选推理过程输出(`cot_responses`) - 结合Schema的上下文推理 - 支持复杂查询的逐步分解 @@ -434,8 +452,7 @@ from dataflow.prompts.text2sql import Text2SQLCotGeneratorPrompt text2sql_cot_generator = Text2SQLCoTGenerator( llm_serving=cot_generation_llm_serving, database_manager=database_manager, - max_retries=3, - enable_retry=True, + sampling_num=3, prompt_template=Text2SQLCotGeneratorPrompt() ) text2sql_cot_generator.run( @@ -443,6 +460,44 @@ text2sql_cot_generator.run( input_sql_key="SQL", input_question_key="question", input_db_id_key="db_id", + input_evidence_key="evidence", + output_cot_key="cot_reasoning" +) +``` + +#### 6. Text2SQLCoTVotingGenerator✨ + +**功能描述:** 对候选CoT进行执行一致性投票,选出最终推理过程 +- 从 `cot_responses` 中提取SQL并执行 +- 基于执行结果一致性进行投票 +- 输出最终 `cot_reasoning` + +**输入参数:** + +- `__init__()` + - `database_manager`: 数据库管理器,用于执行SQL并比较结果 + +- `run()` + - `input_cot_responses_key`: 候选CoT字段名,默认"cot_responses" + - `input_db_id_key`: 数据库ID字段名,默认"db_id" + - `output_cot_key`: 输出最终CoT字段名,默认"cot_reasoning" + +**主要特性:** + +- 基于执行一致性的可靠投票 +- 自动处理无效候选与并列情况 +- 生成最终可用的推理过程 + +**使用示例:** + +```python +text2sql_cot_voter = Text2SQLCoTVotingGenerator( + database_manager=database_manager +) +text2sql_cot_voter.run( + storage=storage.step(), + input_cot_responses_key="cot_responses", + input_db_id_key="db_id", output_cot_key="cot_reasoning" ) ``` @@ -575,47 +630,82 @@ sql_execution_filter.run( ) ``` -#### 2. SQLConsistencyFilter✨ +#### 2. SQLExecutabilityFilter✨ + +**功能描述:** 使用查询计划过滤不可执行SQL +- 通过数据库EXPLAIN生成查询计划 +- 不执行SQL即可判断可执行性 +- 过滤无法执行或不合法的SQL语句 + +**输入参数:** + +- `__init__()` + - `database_manager`: 数据库管理器,用于生成查询计划 + +- `run()` + - `input_sql_key`: SQL语句字段名,默认"SQL" + - `input_db_id_key`: 数据库ID字段名,默认"db_id" + +**主要特性:** + +- 不执行SQL的快速过滤 +- 更低的资源消耗与更高的吞吐 +- 可与执行过滤器组合使用 + +**使用示例:** + +```python +sql_executability_filter = SQLExecutabilityFilter( + database_manager=database_manager +) +sql_executability_filter.run( + storage=storage.step(), + input_sql_key="SQL", + input_db_id_key="db_id" +) +``` + +#### 3. 
Text2SQLCorrespondenceFilter✨ **功能描述:** 验证SQL与问题描述的语义一致性 -- 使用LLM判断SQL执行结果是否回答了问题 +- 使用LLM判断SQL是否回答了问题 - 检查问题与SQL逻辑的匹配度 - 过滤语义不一致的数据对 -- 提升数据集的质量和可靠性 **输入参数:** - `__init__()` - `llm_serving`: LLM服务接口,用于一致性判断 - - `database_manager`: 数据库管理器,用于SQL执行 + - `database_manager`: 数据库管理器,用于Schema读取 - `prompt_template`: 一致性检查的提示词模板 - `run()` - `input_sql_key`: SQL语句字段名,默认"SQL" - `input_db_id_key`: 数据库ID字段名,默认"db_id" - `input_question_key`: 问题字段名,默认"question" + - `input_evidence_key`: 证据字段名,默认"evidence" **主要特性:** - 智能语义一致性检查 -- 结合SQL执行结果和问题语义 +- 结合Schema进行一致性判断 - 自动过滤不匹配的数据对 -- 支持复杂查询的一致性验证 **使用示例:** ```python -from dataflow.prompts.text2sql import SQLConsistencyFilterPrompt +from dataflow.prompts.text2sql import Text2SQLCorrespondenceFilterPrompt -sql_consistency_filter = SQLConsistencyFilter( +text2sql_correspondence_filter = Text2SQLCorrespondenceFilter( llm_serving=llm_serving, database_manager=database_manager, - prompt_template=SQLConsistencyFilterPrompt() + prompt_template=Text2SQLCorrespondenceFilterPrompt() ) -sql_consistency_filter.run( +text2sql_correspondence_filter.run( storage=storage.step(), input_sql_key="SQL", input_db_id_key="db_id", - input_question_key="question" + input_question_key="question", + input_evidence_key="evidence" ) ``` \ No newline at end of file diff --git a/docs/zh/notes/guide/pipelines/Text2SqlPipeline.md b/docs/zh/notes/guide/pipelines/Text2SqlPipeline.md index 42df65d47..e98efe4a5 100644 --- a/docs/zh/notes/guide/pipelines/Text2SqlPipeline.md +++ b/docs/zh/notes/guide/pipelines/Text2SqlPipeline.md @@ -15,26 +15,28 @@ permalink: /zh/guide/text2sqlpipeline/ - **数据优化模式** - 对已有数据进行筛选、扩充和增强,生成高质量训练数据 - - 输入要求:必须包含数据库 ID、自然语言问题和标准 SQL 答案三要素 + - 输入要求:数据库文件,SQL数据样本(每一条数据包含数据库 ID 和 SQL 查询语句) - **数据合成模式** - - 直接从数据库生成训练数据 - - 特点:无需现有数据样本,支持零样本启动 + - 直接基于数据库生成训练数据 + - 输入要求:数据库文件,无需现有SQL数据样本 ### 处理流程 1. **数据过滤** - - 执行过滤:剔除无效 SQL 和无法执行的 SQL 语句 - - 一致性过滤:确保问题、SQL 与数据库 Schema 三者一致 + - 执行过滤:通过执行SQL语句,剔除无效 SQL 和无法执行的 SQL 语句 + - 可执行性过滤:使用数据库对 SQL 生成查询计划,剔除无法执行的 SQL 语句,不执行 SQL,节约时间的同时实现过滤 + - 一致性过滤:确保自然语言问题、SQL 与数据库 Schema 三者一致 2. **数据生成** - - SQL 变体生成:基于现有 SQL 生成语义等效的变体 - - SQL 合成:根据数据库 Schema 生成新的 SQL 语句 + - SQL 变体生成:基于现有 SQL 生成语义更加丰富的变体 + - SQL 合成:根据数据库 Schema 从零开始生成新的 SQL 语句 - 问题生成:基于 SQL 和 Schema 生成对应的自然语言描述 3. **训练数据构建** - 提示词生成:整合自然语言问题、数据库 Schema 和指令提示 - - 思维链生成:构建分步推理过程(Chain-of-Thought) + - 思维链生成:构建分步推理过程(Chain-of-Thought),生成多个推理过程候选 + - 推理过程候选投票:从多个候选中选择最佳的推理过程 4. 
**数据分级** - 语法难度分级:根据 SQL 语句的复杂度划分等级 @@ -80,7 +82,7 @@ export DF_API_KEY="sk-xxxxx" $env:DF_API_KEY = "sk-xxxxx" ``` -在 `text2sql_pipeline_gen.py` 和 `text2sql_pipeline_refine.py` 中配置 API 端点: +在 `text2sql_pipeline_gen.py` 和 `text2sql_pipeline_refine.py` 中配置 API 端点,其中 `llm_serving` 是构建数据的基础模型, `embedding_serving` 用在生成自然语言查询的时候,生成多个查询然后向量化计算相似度选取最佳的查询。 ```python self.llm_serving = APILLMServing_request( @@ -89,25 +91,13 @@ self.llm_serving = APILLMServing_request( max_workers=100 ) -cot_generation_api_llm_serving = APILLMServing_request( - api_url="https://api.openai.com/v1/chat/completions", - model_name="gpt-4o", # 生成思维链时可选用更强大的模型 - max_workers=100 -) - -embedding_serving = APILLMServing_request( +self.embedding_serving = APILLMServing_request( api_url="https://api.openai.com/v1/embeddings", model_name="text-embedding-ada-002", max_workers=100 ) ``` -各服务用途说明: - -- `llm_serving`:处理通用任务 -- `cot_generation_api_llm_serving`:生成复杂推理链(Chain-of-Thought) -- `embedding_serving`:生成文本嵌入向量 - ### 第五步:配置数据库 #### 使用示例数据库 @@ -195,17 +185,15 @@ database_manager = DatabaseManager( #### 6.1 数据优化流水线 -输入数据需包含以下字段: +输入数据需至少应该包含以下字段,其他字段可以保留,不会有影响: - **db_id**:数据库文件名称(数据库 ID) -- **question**:自然语言问题 - **SQL**:标准 SQL 答案 **数据格式示例(JSON):** ```json { "db_id": "california_schools", - "question": "What is the highest eligible free rate for K-12 students in the schools in Alameda County?", "SQL": "SELECT `Free Meal Count (K-12)` / `Enrollment (K-12)` FROM frpm WHERE `County Name` = 'Alameda' ORDER BY (CAST(`Free Meal Count (K-12)` AS REAL) / `Enrollment (K-12)`) DESC LIMIT 1" } ``` @@ -213,8 +201,8 @@ database_manager = DatabaseManager( **存储配置:** ```python self.storage = FileStorage( - first_entry_file_name="../example_data/Text2SQLPipeline/pipeline_refine.jsonl", - cache_path="./cache_local", + first_entry_file_name="../example_data/Text2SQLPipeline/pipeline_refine.jsonl", # 这里也可以替换为你的 SQL 数据集文件路径 + cache_path="./cache", file_name_prefix="dataflow_cache_step", cache_type="jsonl" ) @@ -226,7 +214,7 @@ self.storage = FileStorage( ```python self.storage = FileStorage( - first_entry_file_name="", + first_entry_file_name="../example_data/Text2SQLPipeline/empty.jsonl", # 数据合成流水线不需要原始数据集,但是由于 DataFlow 需要文件输入,引入一个空 jsonl 文件作为输入 cache_path="./cache", file_name_prefix="dataflow_cache_step", cache_type="jsonl" @@ -254,7 +242,7 @@ python api_pipelines/text2sql_pipeline_refine.py ### 3.1 数据过滤器 -#### 3.1.1 **SQL执行过滤器(SQLExecutionFilter)** +#### 3.1.1 ** SQL执行过滤器(SQLExecutionFilter)** **SQL执行过滤器**(`SQLExecutionFilter`)通过实际执行SQL语句来验证其正确性,过滤掉无法正常执行的SQL语句。 @@ -263,8 +251,8 @@ python api_pipelines/text2sql_pipeline_refine.py * 验证SQL语句的可执行性 * 过滤掉语法错误或执行失败的SQL语句 -**输入**:SQL语句和数据库ID -**输出**:可正常执行的SQL语句 +**输入**:SQL语句和数据库ID +**输出**:可正常执行的SQL语句,无法执行的就被删除了 ```python sql_execution_filter = SQLExecutionFilter( @@ -272,23 +260,40 @@ sql_execution_filter = SQLExecutionFilter( ) ``` -#### 3.1.2 **SQL一致性过滤器(SQLConsistencyFilter)** +#### 3.1.2 ** SQL一致性过滤器(Text2SQLCorrespondenceFilter)** -**SQL一致性过滤器**(`SQLConsistencyFilter`)检查SQL语句与问题、数据库Schema之间的一致性,确保生成的SQL能够正确回答对应的问题。 +**SQL一致性过滤器**(`Text2SQLCorrespondenceFilter`)检查SQL语句与问题、数据库Schema之间的一致性,确保生成的SQL能够正确回答对应的问题。 **功能:** -* 验证SQL语句与问题、数据库Schema之间的一致性 -* 过滤掉与问题、数据库Schema不匹配的SQL语句 +* 验证SQL语句与自然语言问题、数据库Schema之间的一致性 +* 过滤掉与自然语言问题、数据库Schema不匹配的SQL语句 -**输入**:SQL语句、数据库ID和问题 -**输出**:与问题一致的SQL语句 +**输入**:SQL语句、数据库ID、自然语言问题和证据 +**输出**:与自然语言问题、数据库Schema一致的SQL语句,不一致的就被过滤删除了 ```python -sql_consistency_filter = SQLConsistencyFilter( +text2sql_correspondence_filter = Text2SQLCorrespondenceFilter( 
llm_serving=llm_serving, database_manager=database_manager, - prompt_template=SQLConsistencyFilterPrompt() + prompt_template=Text2SQLCorrespondenceFilterPrompt() +) +``` + +#### 3.1.3 ** SQL可执行性过滤器(SQLExecutabilityFilter)** + +**SQL可执行性过滤器**(`SQLExecutabilityFilter`)使用数据库对 SQL 生成查询计划(EXPLAIN),剔除无法执行的 SQL 语句。不执行 SQL,节约时间的同时实现过滤。在数据库中可以生成查询计划代表可执行,因此通过这种方式可以过滤掉无法执行的 SQL 语句。 + +**功能:** +* 使用数据库对 SQL 生成查询计划,剔除无法执行的 SQL 语句 +* 不执行 SQL,节约时间的同时实现过滤 + +**输入**:SQL语句、数据库ID +**输出**:可执行的SQL语句,无法执行的就被删除了 + +```python +sql_executability_filter = SQLExecutabilityFilter( + database_manager=database_manager ) ``` @@ -303,8 +308,8 @@ sql_consistency_filter = SQLConsistencyFilter( * 基于数据库schema自动生成SQL查询语句 * 支持批量生成指定数量的SQL语句 -**输入**:数据库schema信息 -**输出**:生成的SQL语句和对应的数据库ID +**输入**:数据库schema信息 +**输出**:生成的SQL语句、对应的数据库ID与SQL复杂度标签(`sql_complexity_type`) ```python sql_generator = SQLGenerator( @@ -345,8 +350,8 @@ sql_variation_generator = SQLVariationGenerator( * 基于SQL语句生成自然语言问题 * 支持生成多个候选问题 -**输入**:SQL语句和数据库ID -**输出**:自然语言问题 +**输入**:SQL语句和数据库ID +**输出**:自然语言问题与证据(`question` / `evidence`),并附带问题类型字段 `question_type` ```python text2sql_question_generator = Text2SQLQuestionGenerator( @@ -367,7 +372,7 @@ text2sql_question_generator = Text2SQLQuestionGenerator( * 生成结构化的提示模板 * 整合问题和数据库schema信息 -**输入**:问题和数据库ID +**输入**:问题、证据和数据库ID **输出**:格式化的提示模板 ```python @@ -379,26 +384,43 @@ text2sql_prompt_generator = Text2SQLPromptGenerator( #### 3.2.5 **长链推理生成器(Text2SQLCoTGenerator)** -**长链推理生成器**(`Text2SQLCoTGenerator`)为SQL查询生成详细的推理过程,帮助模型理解从问题到SQL的转换逻辑。 +**长链推理生成器**(`Text2SQLCoTGenerator`)为SQL查询生成多个详细的推理过程,帮助模型理解从问题到SQL的转换逻辑。 **功能:** * 生成SQL查询的推理过程 -* 支持重试机制确保生成质量 +* 为了确保质量,生成多个推理过程候选(不做验证) -**输入**:SQL语句、问题和数据库ID -**输出**:思维链推理过程 +**输入**:SQL语句、问题、证据和数据库ID +**输出**:多个推理过程候选(`cot_responses`),用于后续投票选择最佳的推理过程 ```python sql_cot_generator = Text2SQLCoTGenerator( - llm_serving=cot_generation_api_llm_serving, + llm_serving=llm_serving, database_manager=database_manager, - max_retries=3, - enable_retry=True, + sampling_num=3, prompt_template=Text2SQLCotGeneratorPrompt() ) ``` +#### 3.2.6 **推理过程投票器(Text2SQLCoTVotingGenerator)** + +**推理过程投票器**(`Text2SQLCoTVotingGenerator`)对多个推理过程候选进行执行一致性投票,选择最佳的推理过程。 + +**功能:** + +* 对多个推理过程候选进行执行一致性投票 +* 选择最佳的推理过程并输出 + +**输入**:`cot_responses` 与数据库ID +**输出**:最终推理过程 `cot_reasoning` + +```python +sql_cot_voter = Text2SQLCoTVotingGenerator( + database_manager=database_manager +) +``` + ### 3.3 数据评估器 #### 3.3.1 **组件难度评估器(SQLComponentClassifier)** @@ -451,7 +473,7 @@ sql_execution_classifier = SQLExecutionClassifier( - `Text2SQLQuestionGeneratorPrompt()` - 问题生成提示词 - `Text2SQLPromptGeneratorPrompt()` - 训练提示词生成 - `Text2SQLCotGeneratorPrompt()` - CoT推理生成提示词 -- `SQLConsistencyFilterPrompt()` - 一致性过滤提示词 +- `Text2SQLCorrespondenceFilterPrompt()` - 一致性过滤提示词 ## 4. 
**输出数据** @@ -459,7 +481,11 @@ sql_execution_classifier = SQLExecutionClassifier( - **字段说明**: - `db_id`: 数据库id - `question`: 自然语言问题 + - `question_type`: 自然语言问题类型 +- `evidence`: 问题生成时附带的证据/外部知识 - `SQL`: 标准SQL答案 + - `sql_variation_type`: SQL变体类型(仅在 SQL 优化流水线生成的数据中存在) + - `sql_complexity_type`: SQL复杂度类型(仅在 SQL 合成流水线生成的数据中存在) - `prompt`: 用于训练的提示词,包含自然语言问题、数据库Schema和提示信息 - `cot_reasoning`: 长链推理数据,包含推理过程和最终答案,用于模型训练 - `sql_component_difficulty`: SQL组件复杂度评估 @@ -470,6 +496,7 @@ sql_execution_classifier = SQLExecutionClassifier( "db_id":"california_schools", "SQL":"SELECT AVG(s.AvgScrRead) AS average_reading_score\nFROM satscores s\nINNER JOIN frpm f ON s.cds = f.CDSCode\nINNER JOIN schools sc ON f.CDSCode = sc.CDSCode\nWHERE s.cname = 'Alameda'\n AND f.\"Charter School (Y\/N)\" = 1\n AND f.\"Charter Funding Type\" = 'Directly funded'\n AND sc.County = 'Alameda';", "question":"What is the average reading score for directly funded charter schools in Alameda County?", + "evidence":"This question focuses on directly funded charter schools in Alameda County.", "prompt":"Task Overview: /* Given the following database schema: ... /* Answer the following: What is the average reading score for directly funded charter schools in Alameda County? * Let's think step by step", "cot_reasoning":"To translate the natural language question into an executable SQLite query, we will follow these steps. ... we can construct the full SQLite query based on these steps:\n\n```sql\nSELECT AVG(s.AvgScrRead) AS average_reading_score\nFROM satscores s\nINNER JOIN frpm f ON s.cds = f.CDSCode\nINNER JOIN schools sc ON f.CDSCode = sc.CDSCode\nWHERE s.cname = 'Alameda'\n AND f.\"Charter School (Y\/N)\" = 1\n AND f.\"Charter Funding Type\" = 'Directly funded'\n AND sc.County = 'Alameda';\n```\n\nThis query follows the logic outlined above and ensures alignment with the reference solution.", "sql_component_difficulty":"medium", @@ -484,13 +511,13 @@ sql_execution_classifier = SQLExecutionClassifier( * **数据合成流水线**: ```bash - python /path/to/text2sql_generation_pipeline.py + python api_pipelines/text2sql_pipeline_gen.py ``` * **数据优化流水线**: ```bash - python /path/to/text2sql_refine_pipeline.py + python api_pipelines/text2sql_pipeline_refine.py ``` ## 6. 
流水线示例 @@ -505,7 +532,6 @@ class Text2SQLGeneration_APIPipeline(): self.logger = get_logger() self.db_root_path = db_root_path - # 自动下载数据库功能 if not db_root_path: try: self.db_root_path = download_and_extract_database(self.logger) @@ -516,31 +542,43 @@ class Text2SQLGeneration_APIPipeline(): else: self.logger.info(f"Using manually specified database path: {self.db_root_path}") + if not os.path.exists(self.db_root_path): + raise FileNotFoundError(f"Database path does not exist: {self.db_root_path}") + self.storage = FileStorage( - first_entry_file_name="", + first_entry_file_name="../example_data/Text2SQLPipeline/empty.jsonl", cache_path="./cache", file_name_prefix="dataflow_cache_step", cache_type="jsonl", ) self.llm_serving = APILLMServing_request( - api_url="http://api.openai.com/v1/chat/completions", + api_url="https://api.openai.com/v1/chat/completions", model_name="gpt-4o", max_workers=100 ) - cot_generation_api_llm_serving = APILLMServing_request( - api_url="http://api.openai.com/v1/chat/completions", - model_name="gpt-4o", - max_workers=100 - ) - - embedding_serving = APILLMServing_request( - api_url="http://api.openai.com/v1/embeddings", + self.embedding_serving = APILLMServing_request( + api_url="https://api.openai.com/v1/embeddings", model_name="text-embedding-ada-002", max_workers=100 ) + # SQLite and MySQL are currently supported + # db_type can be sqlite or mysql, which must match your database type + # If sqlite is selected, root_path must be provided, this path must exist and contain database files + # If mysql is selected, host, user, password must be provided, these credentials must be correct and have access permissions + # MySQL example: + # database_manager = DatabaseManager( + # db_type="mysql", + # config={ + # "host": "localhost", + # "user": "root", + # "password": "your_password", + # "database": "your_database_name" + # } + # ) + # SQLite example: database_manager = DatabaseManager( db_type="sqlite", config={ @@ -551,41 +589,49 @@ class Text2SQLGeneration_APIPipeline(): self.sql_generator_step1 = SQLGenerator( llm_serving=self.llm_serving, database_manager=database_manager, - generate_num=50, + generate_num=2, prompt_template=SelectSQLGeneratorPrompt() ) - self.sql_execution_filter_step2 = SQLExecutionFilter( - database_manager=database_manager, + self.sql_executability_filter_step2 = SQLExecutabilityFilter( + database_manager=database_manager ) self.text2sql_question_generator_step3 = Text2SQLQuestionGenerator( llm_serving=self.llm_serving, - embedding_serving=embedding_serving, + embedding_serving=self.embedding_serving, database_manager=database_manager, - question_candidates_num=5, + question_candidates_num=3, prompt_template=Text2SQLQuestionGeneratorPrompt() ) - self.text2sql_prompt_generator_step4 = Text2SQLPromptGenerator( + self.text2sql_correspondence_filter_step4 = Text2SQLCorrespondenceFilter( + llm_serving=self.llm_serving, + database_manager=database_manager, + prompt_template=Text2SQLCorrespondenceFilterPrompt() + ) + + self.text2sql_prompt_generator_step5 = Text2SQLPromptGenerator( database_manager=database_manager, prompt_template=Text2SQLPromptGeneratorPrompt() ) - self.sql_cot_generator_step5 = Text2SQLCoTGenerator( - llm_serving=cot_generation_api_llm_serving, + self.sql_cot_generator_step6 = Text2SQLCoTGenerator( + llm_serving=self.llm_serving, database_manager=database_manager, - max_retries=3, - enable_retry=True, prompt_template=Text2SQLCotGeneratorPrompt() ) - self.sql_component_classifier_step6 = SQLComponentClassifier( + 
self.sql_cot_voting_generator_step7 = Text2SQLCoTVotingGenerator( + database_manager=database_manager + ) + + self.sql_component_classifier_step8 = SQLComponentClassifier( difficulty_thresholds=[2, 4, 6], difficulty_labels=['easy', 'medium', 'hard', 'extra'] ) - self.sql_execution_classifier_step7 = SQLExecutionClassifier( + self.sql_execution_classifier_step9 = SQLExecutionClassifier( llm_serving=self.llm_serving, database_manager=database_manager, num_generations=10, @@ -594,17 +640,20 @@ class Text2SQLGeneration_APIPipeline(): ) def forward(self): + sql_key = "SQL" db_id_key = "db_id" question_key = "question" + evidence_key = "evidence" self.sql_generator_step1.run( storage=self.storage.step(), output_sql_key=sql_key, - output_db_id_key=db_id_key + output_db_id_key=db_id_key, + output_sql_complexity_key="sql_complexity_type" ) - self.sql_execution_filter_step2.run( + self.sql_executability_filter_step2.run( storage=self.storage.step(), input_sql_key=sql_key, input_db_id_key=db_id_key @@ -614,31 +663,49 @@ class Text2SQLGeneration_APIPipeline(): storage=self.storage.step(), input_sql_key=sql_key, input_db_id_key=db_id_key, - output_question_key=question_key + output_question_key=question_key, + output_evidence_key=evidence_key + ) + + self.text2sql_correspondence_filter_step4.run( + storage=self.storage.step(), + input_sql_key=sql_key, + input_db_id_key=db_id_key, + input_question_key=question_key, + input_evidence_key=evidence_key ) - self.text2sql_prompt_generator_step4.run( + self.text2sql_prompt_generator_step5.run( storage=self.storage.step(), input_question_key=question_key, input_db_id_key=db_id_key, + input_evidence_key=evidence_key, output_prompt_key="prompt" ) - self.sql_cot_generator_step5.run( + self.sql_cot_generator_step6.run( storage=self.storage.step(), input_sql_key=sql_key, input_question_key=question_key, input_db_id_key=db_id_key, + input_evidence_key=evidence_key, + output_cot_key="cot_reasoning" + ) + + self.sql_cot_voting_generator_step7.run( + storage=self.storage.step(), + input_cot_responses_key="cot_responses", + input_db_id_key=db_id_key, output_cot_key="cot_reasoning" ) - self.sql_component_classifier_step6.run( + self.sql_component_classifier_step8.run( storage=self.storage.step(), input_sql_key=sql_key, output_difficulty_key="sql_component_difficulty" ) - self.sql_execution_classifier_step7.run( + self.sql_execution_classifier_step9.run( storage=self.storage.step(), input_sql_key=sql_key, input_db_id_key=db_id_key, @@ -647,10 +714,10 @@ class Text2SQLGeneration_APIPipeline(): ) if __name__ == "__main__": - # 如果有自己的数据库文件,可以设置db_root_path为数据库文件路径 - # 如果没有,请设置db_root_path为"",系统将自动下载示例数据库 + # If you have your own database files, you can set the db_root_path to the path of your database files + # If not, please set the db_root_path "", and we will download the example database files automatically db_root_path = "" - + model = Text2SQLGeneration_APIPipeline(db_root_path=db_root_path) model.forward() ``` @@ -662,7 +729,6 @@ class Text2SQLRefine_APIPipeline(): self.logger = get_logger() self.db_root_path = db_root_path - # 自动下载数据库功能 if not db_root_path: try: self.db_root_path = download_and_extract_database(self.logger) @@ -673,31 +739,43 @@ class Text2SQLRefine_APIPipeline(): else: self.logger.info(f"Using manually specified database path: {self.db_root_path}") + if not os.path.exists(self.db_root_path): + raise FileNotFoundError(f"Database path does not exist: {self.db_root_path}") + self.storage = FileStorage( 
first_entry_file_name="../example_data/Text2SQLPipeline/pipeline_refine.jsonl", - cache_path="./cache_local", + cache_path="./cache", file_name_prefix="dataflow_cache_step", cache_type="jsonl" ) self.llm_serving = APILLMServing_request( - api_url="http://api.openai.com/v1/chat/completions", + api_url="https://api.openai.com/v1/chat/completions", model_name="gpt-4o", max_workers=100 ) - cot_generation_api_llm_serving = APILLMServing_request( - api_url="http://api.openai.com/v1/chat/completions", - model_name="gpt-4o", - max_workers=100 - ) - - embedding_serving = APILLMServing_request( - api_url="http://api.openai.com/v1/embeddings", + self.embedding_serving = APILLMServing_request( + api_url="https://api.openai.com/v1/embeddings", model_name="text-embedding-ada-002", max_workers=100 ) + # SQLite and MySQL are currently supported + # db_type can be sqlite or mysql, which must match your database type + # If sqlite is selected, root_path must be provided, this path must exist and contain database files + # If mysql is selected, host, user, password must be provided, these credentials must be correct and have access permissions + # MySQL example: + # database_manager = DatabaseManager( + # db_type="mysql", + # config={ + # "host": "localhost", + # "user": "root", + # "password": "your_password", + # "database": "your_database_name" + # } + # ) + # SQLite example: database_manager = DatabaseManager( db_type="sqlite", config={ @@ -705,38 +783,111 @@ class Text2SQLRefine_APIPipeline(): } ) - self.sql_execution_filter_step1 = SQLExecutionFilter( + self.sql_executability_filter_step1 = SQLExecutabilityFilter( database_manager=database_manager ) - self.sql_consistency_filter_step2 = SQLConsistencyFilter( + self.sql_variation_generator_step2 = SQLVariationGenerator( llm_serving=self.llm_serving, database_manager=database_manager, - prompt_template=SQLConsistencyFilterPrompt() + num_variations=3, # Number of variations to generate for each SQL + prompt_template=SQLVariationGeneratorPrompt() ) - self.sql_variation_generator_step3 = SQLVariationGenerator( + self.sql_executability_filter_step3 = SQLExecutabilityFilter( + database_manager=database_manager + ) + + self.text2sql_question_generator_step4 = Text2SQLQuestionGenerator( llm_serving=self.llm_serving, + embedding_serving=self.embedding_serving, database_manager=database_manager, - num_variations=5, - prompt_template=SQLVariationGeneratorPrompt() + question_candidates_num=3, + prompt_template=Text2SQLQuestionGeneratorPrompt() + ) + + self.text2sql_correspondence_filter_step5 = Text2SQLCorrespondenceFilter( + llm_serving=self.llm_serving, + database_manager=database_manager, + prompt_template=Text2SQLCorrespondenceFilterPrompt() + ) + + self.text2sql_prompt_generator_step6 = Text2SQLPromptGenerator( + database_manager=database_manager, + prompt_template=Text2SQLPromptGeneratorPrompt() + ) + + self.sql_cot_generator_step7 = Text2SQLCoTGenerator( + llm_serving=self.llm_serving, + database_manager=database_manager, + prompt_template=Text2SQLCotGeneratorPrompt() ) - self.sql_execution_filter_step4 = SQLExecutionFilter( + self.sql_cot_voting_generator_step8 = Text2SQLCoTVotingGenerator( database_manager=database_manager ) - self.text2sql_question_generator_step5.run( + self.sql_component_classifier_step9 = SQLComponentClassifier( + difficulty_thresholds=[2, 4, 6], + difficulty_labels=['easy', 'medium', 'hard', 'extra'] + ) + + self.sql_execution_classifier_step10 = SQLExecutionClassifier( + llm_serving=self.llm_serving, + 
database_manager=database_manager, + num_generations=10, + difficulty_thresholds=[2, 5, 9], + difficulty_labels=['extra', 'hard', 'medium', 'easy'] + ) + + + def forward(self): + + sql_key = "SQL" + db_id_key = "db_id" + question_key = "question" + evidence_key = "evidence" + + self.sql_executability_filter_step1.run( + storage=self.storage.step(), + input_sql_key=sql_key, + input_db_id_key=db_id_key + ) + + self.sql_variation_generator_step2.run( + storage=self.storage.step(), + input_sql_key=sql_key, + input_db_id_key=db_id_key, + output_sql_variation_type_key="sql_variation_type" + ) + + self.sql_executability_filter_step3.run( + storage=self.storage.step(), + input_sql_key=sql_key, + input_db_id_key=db_id_key + ) + + self.text2sql_question_generator_step4.run( + storage=self.storage.step(), + input_sql_key=sql_key, + input_db_id_key=db_id_key, + output_question_key=question_key, + output_evidence_key=evidence_key + ) + + self.text2sql_correspondence_filter_step5.run( storage=self.storage.step(), input_sql_key=sql_key, input_db_id_key=db_id_key, - output_question_key=question_key + input_question_key=question_key, + input_evidence_key=evidence_key ) self.text2sql_prompt_generator_step6.run( storage=self.storage.step(), input_question_key=question_key, input_db_id_key=db_id_key, + input_evidence_key=evidence_key, output_prompt_key="prompt" ) @@ -745,16 +896,24 @@ class Text2SQLRefine_APIPipeline(): input_sql_key=sql_key, input_question_key=question_key, input_db_id_key=db_id_key, + input_evidence_key=evidence_key, output_cot_key="cot_reasoning" ) - self.sql_component_classifier_step8.run( + self.sql_cot_voting_generator_step8.run( + storage=self.storage.step(), + input_cot_responses_key="cot_responses", + input_db_id_key=db_id_key, + output_cot_key="cot_reasoning" + ) + + self.sql_component_classifier_step9.run( storage=self.storage.step(), input_sql_key=sql_key, output_difficulty_key="sql_component_difficulty" ) - self.sql_execution_classifier_step9.run( + self.sql_execution_classifier_step10.run( storage=self.storage.step(), input_sql_key=sql_key, input_db_id_key=db_id_key, @@ -763,9 +922,10 @@ class Text2SQLRefine_APIPipeline(): ) if __name__ == "__main__": - # 如果有自己的数据库文件,可以设置db_root_path为数据库文件路径 - # 如果没有,请设置db_root_path为"",系统将自动下载示例数据库 + # If you have your own database files, you can set the db_root_path to the path of your database files + # If not, please set the db_root_path "", and we will download the example database files automatically db_root_path = "" model = Text2SQLRefine_APIPipeline(db_root_path=db_root_path) - model.forward() \ No newline at end of file + model.forward() +``` \ No newline at end of file