|
1 | 1 | from sqlglot import Tokenizer |
2 | 2 | from sqlmesh.core.context import Context |
3 | | -from sqlmesh.lsp.completions import get_keywords_from_tokenizer, get_sql_completions |
| 3 | +from sqlmesh.lsp.completions import ( |
| 4 | + get_keywords_from_tokenizer, |
| 5 | + get_sql_completions, |
| 6 | + extract_keywords_from_content, |
| 7 | +) |
4 | 8 | from sqlmesh.lsp.context import LSPContext |
5 | 9 | from sqlmesh.lsp.uri import URI |
6 | 10 |
|
@@ -36,3 +40,125 @@ def test_get_sql_completions_with_context_and_file_uri(): |
36 | 40 | completions = lsp_context.get_autocomplete(URI.from_path(file_uri)) |
37 | 41 | assert len(completions.keywords) > len(TOKENIZER_KEYWORDS) |
38 | 42 | assert "sushi.active_customers" not in completions.models |
| 43 | + |
| 44 | + |
| 45 | +def test_extract_keywords_from_content(): |
| 46 | + # Test extracting keywords from SQL content |
| 47 | + content = """ |
| 48 | + SELECT customer_id, order_date, total_amount |
| 49 | + FROM orders o |
| 50 | + JOIN customers c ON o.customer_id = c.id |
| 51 | + WHERE order_date > '2024-01-01' |
| 52 | + """ |
| 53 | + |
| 54 | + keywords = extract_keywords_from_content(content) |
| 55 | + |
| 56 | + # Check that identifiers are extracted |
| 57 | + assert "customer_id" in keywords |
| 58 | + assert "order_date" in keywords |
| 59 | + assert "total_amount" in keywords |
| 60 | + assert "orders" in keywords |
| 61 | + assert "customers" in keywords |
| 62 | + assert "o" in keywords # alias |
| 63 | + assert "c" in keywords # alias |
| 64 | + assert "id" in keywords |
| 65 | + |
| 66 | + # Check that SQL keywords are NOT included |
| 67 | + assert "SELECT" not in keywords |
| 68 | + assert "FROM" not in keywords |
| 69 | + assert "JOIN" not in keywords |
| 70 | + assert "WHERE" not in keywords |
| 71 | + assert "ON" not in keywords |
| 72 | + |
| 73 | + |
| 74 | +def test_get_sql_completions_with_file_content(): |
| 75 | + context = Context(paths=["examples/sushi"]) |
| 76 | + lsp_context = LSPContext(context) |
| 77 | + |
| 78 | + # SQL content with custom identifiers |
| 79 | + content = """ |
| 80 | + SELECT my_custom_column, another_identifier |
| 81 | + FROM my_custom_table mct |
| 82 | + JOIN some_other_table sot ON mct.id = sot.table_id |
| 83 | + WHERE my_custom_column > 100 |
| 84 | + """ |
| 85 | + |
| 86 | + file_uri = next(key for key in lsp_context.map.keys() if key.name == "active_customers.sql") |
| 87 | + completions = lsp_context.get_autocomplete(URI.from_path(file_uri), content) |
| 88 | + |
| 89 | + # Check that SQL keywords are included |
| 90 | + assert any(k in ["SELECT", "FROM", "WHERE", "JOIN"] for k in completions.keywords) |
| 91 | + |
| 92 | + # Check that file-specific identifiers are included at the end |
| 93 | + keywords_list = completions.keywords |
| 94 | + assert "my_custom_column" in keywords_list |
| 95 | + assert "another_identifier" in keywords_list |
| 96 | + assert "my_custom_table" in keywords_list |
| 97 | + assert "some_other_table" in keywords_list |
| 98 | + assert "mct" in keywords_list # alias |
| 99 | + assert "sot" in keywords_list # alias |
| 100 | + assert "table_id" in keywords_list |
| 101 | + |
| 102 | + # Check that file keywords come after SQL keywords |
| 103 | + # SQL keywords should appear first in the list |
| 104 | + sql_keyword_indices = [ |
| 105 | + i for i, k in enumerate(keywords_list) if k in ["SELECT", "FROM", "WHERE", "JOIN"] |
| 106 | + ] |
| 107 | + file_keyword_indices = [ |
| 108 | + i for i, k in enumerate(keywords_list) if k in ["my_custom_column", "my_custom_table"] |
| 109 | + ] |
| 110 | + |
| 111 | + if sql_keyword_indices and file_keyword_indices: |
| 112 | + assert max(sql_keyword_indices) < min(file_keyword_indices), ( |
| 113 | + "SQL keywords should come before file keywords" |
| 114 | + ) |
| 115 | + |
| 116 | + |
| 117 | +def test_get_sql_completions_with_partial_cte_query(): |
| 118 | + context = Context(paths=["examples/sushi"]) |
| 119 | + lsp_context = LSPContext(context) |
| 120 | + |
| 121 | + # Partial SQL query with CTEs |
| 122 | + content = """ |
| 123 | + WITH _latest_complete_month AS ( |
| 124 | + SELECT MAX(date_trunc('month', order_date)) as month |
| 125 | + FROM orders |
| 126 | + ), |
| 127 | + _filtered AS ( |
| 128 | + SELECT * FROM |
| 129 | + """ |
| 130 | + |
| 131 | + file_uri = next(key for key in lsp_context.map.keys() if key.name == "active_customers.sql") |
| 132 | + completions = lsp_context.get_autocomplete(URI.from_path(file_uri), content) |
| 133 | + |
| 134 | + # Check that CTE names are included in the keywords |
| 135 | + keywords_list = completions.keywords |
| 136 | + assert "_latest_complete_month" in keywords_list |
| 137 | + assert "_filtered" in keywords_list |
| 138 | + |
| 139 | + # Also check other identifiers from the partial query |
| 140 | + assert "month" in keywords_list |
| 141 | + assert "order_date" in keywords_list |
| 142 | + assert "orders" in keywords_list |
| 143 | + |
| 144 | + |
| 145 | +def test_extract_keywords_from_partial_query(): |
| 146 | + # Test extracting keywords from an incomplete SQL query |
| 147 | + content = """ |
| 148 | + WITH cte1 AS ( |
| 149 | + SELECT col1, col2 FROM table1 |
| 150 | + ), |
| 151 | + cte2 AS ( |
| 152 | + SELECT * FROM cte1 WHERE |
| 153 | + """ |
| 154 | + |
| 155 | + keywords = extract_keywords_from_content(content) |
| 156 | + |
| 157 | + # Check that CTEs are extracted |
| 158 | + assert "cte1" in keywords |
| 159 | + assert "cte2" in keywords |
| 160 | + |
| 161 | + # Check that columns and tables are extracted |
| 162 | + assert "col1" in keywords |
| 163 | + assert "col2" in keywords |
| 164 | + assert "table1" in keywords |
0 commit comments