|
| 1 | +package earlyanalysis_test |
| 2 | + |
| 3 | +import ( |
| 4 | + "testing" |
| 5 | + |
| 6 | + "github.com/stackql/stackql-parser/go/vt/sqlparser" |
| 7 | + "github.com/stretchr/testify/assert" |
| 8 | + "github.com/stretchr/testify/require" |
| 9 | +) |
| 10 | + |
| 11 | +func TestCTEParsing(t *testing.T) { |
| 12 | + t.Run("Simple CTE is parsed correctly", func(t *testing.T) { |
| 13 | + query := "WITH cte AS (SELECT id, name FROM users) SELECT * FROM cte" |
| 14 | + stmt, err := sqlparser.Parse(query) |
| 15 | + require.NoError(t, err) |
| 16 | + |
| 17 | + sel, ok := stmt.(*sqlparser.Select) |
| 18 | + require.True(t, ok, "Statement should be a SELECT") |
| 19 | + |
| 20 | + // Check that With clause exists |
| 21 | + require.NotNil(t, sel.With, "WITH clause should exist") |
| 22 | + require.Len(t, sel.With.CTEs, 1, "Should have 1 CTE") |
| 23 | + |
| 24 | + // Check CTE name |
| 25 | + cte := sel.With.CTEs[0] |
| 26 | + assert.Equal(t, "cte", cte.Name.GetRawVal(), "CTE name should be 'cte'") |
| 27 | + |
| 28 | + // Check that CTE has a select statement |
| 29 | + require.NotNil(t, cte.Select, "CTE should have a select statement") |
| 30 | + }) |
| 31 | + |
| 32 | + t.Run("Multiple CTEs are parsed correctly", func(t *testing.T) { |
| 33 | + query := "WITH a AS (SELECT 1 as x), b AS (SELECT 2 as y) SELECT * FROM a, b" |
| 34 | + stmt, err := sqlparser.Parse(query) |
| 35 | + require.NoError(t, err) |
| 36 | + |
| 37 | + sel, ok := stmt.(*sqlparser.Select) |
| 38 | + require.True(t, ok, "Statement should be a SELECT") |
| 39 | + |
| 40 | + // Check that With clause exists |
| 41 | + require.NotNil(t, sel.With, "WITH clause should exist") |
| 42 | + require.Len(t, sel.With.CTEs, 2, "Should have 2 CTEs") |
| 43 | + |
| 44 | + // Check CTE names |
| 45 | + assert.Equal(t, "a", sel.With.CTEs[0].Name.GetRawVal(), "First CTE name should be 'a'") |
| 46 | + assert.Equal(t, "b", sel.With.CTEs[1].Name.GetRawVal(), "Second CTE name should be 'b'") |
| 47 | + }) |
| 48 | + |
| 49 | + t.Run("Recursive CTE is parsed correctly", func(t *testing.T) { |
| 50 | + query := "WITH RECURSIVE cte AS (SELECT 1 as n UNION ALL SELECT n + 1 FROM cte WHERE n < 10) SELECT * FROM cte" |
| 51 | + stmt, err := sqlparser.Parse(query) |
| 52 | + require.NoError(t, err) |
| 53 | + |
| 54 | + sel, ok := stmt.(*sqlparser.Select) |
| 55 | + require.True(t, ok, "Statement should be a SELECT") |
| 56 | + |
| 57 | + // Check that With clause exists with Recursive flag |
| 58 | + require.NotNil(t, sel.With, "WITH clause should exist") |
| 59 | + assert.True(t, sel.With.Recursive, "WITH clause should be RECURSIVE") |
| 60 | + require.Len(t, sel.With.CTEs, 1, "Should have 1 CTE") |
| 61 | + |
| 62 | + // Check CTE name |
| 63 | + assert.Equal(t, "cte", sel.With.CTEs[0].Name.GetRawVal(), "CTE name should be 'cte'") |
| 64 | + }) |
| 65 | + |
| 66 | + t.Run("CTE with column aliases", func(t *testing.T) { |
| 67 | + query := "WITH cte(col1, col2) AS (SELECT id, name FROM users) SELECT * FROM cte" |
| 68 | + stmt, err := sqlparser.Parse(query) |
| 69 | + require.NoError(t, err) |
| 70 | + |
| 71 | + sel, ok := stmt.(*sqlparser.Select) |
| 72 | + require.True(t, ok, "Statement should be a SELECT") |
| 73 | + |
| 74 | + require.NotNil(t, sel.With, "WITH clause should exist") |
| 75 | + require.Len(t, sel.With.CTEs, 1, "Should have 1 CTE") |
| 76 | + |
| 77 | + cte := sel.With.CTEs[0] |
| 78 | + assert.Equal(t, "cte", cte.Name.GetRawVal(), "CTE name should be 'cte'") |
| 79 | + |
| 80 | + // Check column aliases if present |
| 81 | + require.Len(t, cte.Columns, 2, "CTE should have 2 column aliases") |
| 82 | + assert.Equal(t, "col1", cte.Columns[0].GetRawVal(), "First column alias should be 'col1'") |
| 83 | + assert.Equal(t, "col2", cte.Columns[1].GetRawVal(), "Second column alias should be 'col2'") |
| 84 | + }) |
| 85 | + |
| 86 | + t.Run("Nested CTEs - CTE referencing another CTE", func(t *testing.T) { |
| 87 | + query := "WITH a AS (SELECT 1 as x), b AS (SELECT x * 2 as y FROM a) SELECT * FROM b" |
| 88 | + stmt, err := sqlparser.Parse(query) |
| 89 | + require.NoError(t, err) |
| 90 | + |
| 91 | + sel, ok := stmt.(*sqlparser.Select) |
| 92 | + require.True(t, ok, "Statement should be a SELECT") |
| 93 | + |
| 94 | + require.NotNil(t, sel.With, "WITH clause should exist") |
| 95 | + require.Len(t, sel.With.CTEs, 2, "Should have 2 CTEs") |
| 96 | + }) |
| 97 | +} |
| 98 | + |
| 99 | +func TestCTERegistry(t *testing.T) { |
| 100 | + t.Run("CTE registry stores CTEs correctly", func(t *testing.T) { |
| 101 | + registry := make(map[string]*sqlparser.CommonTableExpr) |
| 102 | + |
| 103 | + query := "WITH cte1 AS (SELECT 1), cte2 AS (SELECT 2) SELECT * FROM cte1, cte2" |
| 104 | + stmt, err := sqlparser.Parse(query) |
| 105 | + require.NoError(t, err) |
| 106 | + |
| 107 | + sel := stmt.(*sqlparser.Select) |
| 108 | + require.NotNil(t, sel.With) |
| 109 | + |
| 110 | + // Simulate what the visitor does - register CTEs |
| 111 | + for _, cte := range sel.With.CTEs { |
| 112 | + cteName := cte.Name.GetRawVal() |
| 113 | + registry[cteName] = cte |
| 114 | + } |
| 115 | + |
| 116 | + // Verify registry contents |
| 117 | + assert.Len(t, registry, 2, "Registry should have 2 CTEs") |
| 118 | + assert.Contains(t, registry, "cte1", "Registry should contain 'cte1'") |
| 119 | + assert.Contains(t, registry, "cte2", "Registry should contain 'cte2'") |
| 120 | + }) |
| 121 | + |
| 122 | + t.Run("CTE lookup works correctly", func(t *testing.T) { |
| 123 | + registry := make(map[string]*sqlparser.CommonTableExpr) |
| 124 | + |
| 125 | + query := "WITH my_cte AS (SELECT id, name FROM users) SELECT * FROM my_cte" |
| 126 | + stmt, err := sqlparser.Parse(query) |
| 127 | + require.NoError(t, err) |
| 128 | + |
| 129 | + sel := stmt.(*sqlparser.Select) |
| 130 | + require.NotNil(t, sel.With) |
| 131 | + |
| 132 | + // Register the CTE |
| 133 | + for _, cte := range sel.With.CTEs { |
| 134 | + cteName := cte.Name.GetRawVal() |
| 135 | + registry[cteName] = cte |
| 136 | + } |
| 137 | + |
| 138 | + // Verify we can look up the CTE |
| 139 | + _, isCTE := registry["my_cte"] |
| 140 | + assert.True(t, isCTE, "'my_cte' should be found in registry") |
| 141 | + |
| 142 | + // Verify non-CTE names are not found |
| 143 | + _, isNotCTE := registry["users"] |
| 144 | + assert.False(t, isNotCTE, "'users' should not be found in registry") |
| 145 | + }) |
| 146 | +} |
| 147 | + |
| 148 | +func TestWindowFunctionParsing(t *testing.T) { |
| 149 | + t.Run("Window function with OVER clause is parsed correctly", func(t *testing.T) { |
| 150 | + query := "SELECT ROW_NUMBER() OVER (ORDER BY id) as row_num FROM t" |
| 151 | + stmt, err := sqlparser.Parse(query) |
| 152 | + require.NoError(t, err) |
| 153 | + |
| 154 | + sel, ok := stmt.(*sqlparser.Select) |
| 155 | + require.True(t, ok, "Statement should be a SELECT") |
| 156 | + |
| 157 | + require.Len(t, sel.SelectExprs, 1, "Should have 1 select expression") |
| 158 | + |
| 159 | + aliased, ok := sel.SelectExprs[0].(*sqlparser.AliasedExpr) |
| 160 | + require.True(t, ok, "Select expression should be aliased") |
| 161 | + |
| 162 | + funcExpr, ok := aliased.Expr.(*sqlparser.FuncExpr) |
| 163 | + require.True(t, ok, "Expression should be a FuncExpr") |
| 164 | + |
| 165 | + assert.Equal(t, "row_number", funcExpr.Name.Lowered(), "Function name should be 'row_number'") |
| 166 | + assert.NotNil(t, funcExpr.Over, "FuncExpr should have OVER clause") |
| 167 | + }) |
| 168 | + |
| 169 | + t.Run("Window function with PARTITION BY is parsed correctly", func(t *testing.T) { |
| 170 | + query := "SELECT SUM(amount) OVER (PARTITION BY category ORDER BY date) as running_sum FROM t" |
| 171 | + stmt, err := sqlparser.Parse(query) |
| 172 | + require.NoError(t, err) |
| 173 | + |
| 174 | + sel := stmt.(*sqlparser.Select) |
| 175 | + aliased := sel.SelectExprs[0].(*sqlparser.AliasedExpr) |
| 176 | + funcExpr := aliased.Expr.(*sqlparser.FuncExpr) |
| 177 | + |
| 178 | + assert.Equal(t, "sum", funcExpr.Name.Lowered()) |
| 179 | + require.NotNil(t, funcExpr.Over, "FuncExpr should have OVER clause") |
| 180 | + |
| 181 | + // Check partition by exists |
| 182 | + require.NotNil(t, funcExpr.Over.PartitionBy, "OVER clause should have PARTITION BY") |
| 183 | + }) |
| 184 | + |
| 185 | + t.Run("Window function with frame specification", func(t *testing.T) { |
| 186 | + query := "SELECT SUM(value) OVER (ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as cumsum FROM t" |
| 187 | + stmt, err := sqlparser.Parse(query) |
| 188 | + require.NoError(t, err) |
| 189 | + |
| 190 | + sel := stmt.(*sqlparser.Select) |
| 191 | + aliased := sel.SelectExprs[0].(*sqlparser.AliasedExpr) |
| 192 | + funcExpr := aliased.Expr.(*sqlparser.FuncExpr) |
| 193 | + |
| 194 | + assert.NotNil(t, funcExpr.Over, "FuncExpr should have OVER clause") |
| 195 | + }) |
| 196 | + |
| 197 | + t.Run("Multiple window functions in query", func(t *testing.T) { |
| 198 | + query := "SELECT ROW_NUMBER() OVER (ORDER BY id) as rn, RANK() OVER (ORDER BY score DESC) as rank FROM t" |
| 199 | + stmt, err := sqlparser.Parse(query) |
| 200 | + require.NoError(t, err) |
| 201 | + |
| 202 | + sel := stmt.(*sqlparser.Select) |
| 203 | + require.Len(t, sel.SelectExprs, 2, "Should have 2 select expressions") |
| 204 | + |
| 205 | + // Check first window function |
| 206 | + aliased1 := sel.SelectExprs[0].(*sqlparser.AliasedExpr) |
| 207 | + funcExpr1 := aliased1.Expr.(*sqlparser.FuncExpr) |
| 208 | + assert.NotNil(t, funcExpr1.Over, "First FuncExpr should have OVER clause") |
| 209 | + |
| 210 | + // Check second window function |
| 211 | + aliased2 := sel.SelectExprs[1].(*sqlparser.AliasedExpr) |
| 212 | + funcExpr2 := aliased2.Expr.(*sqlparser.FuncExpr) |
| 213 | + assert.NotNil(t, funcExpr2.Over, "Second FuncExpr should have OVER clause") |
| 214 | + }) |
| 215 | + |
| 216 | + t.Run("Regular function without OVER clause", func(t *testing.T) { |
| 217 | + query := "SELECT UPPER(name) as upper_name FROM t" |
| 218 | + stmt, err := sqlparser.Parse(query) |
| 219 | + require.NoError(t, err) |
| 220 | + |
| 221 | + sel := stmt.(*sqlparser.Select) |
| 222 | + aliased := sel.SelectExprs[0].(*sqlparser.AliasedExpr) |
| 223 | + funcExpr := aliased.Expr.(*sqlparser.FuncExpr) |
| 224 | + |
| 225 | + assert.Equal(t, "upper", funcExpr.Name.Lowered()) |
| 226 | + assert.Nil(t, funcExpr.Over, "UPPER() should not have OVER clause") |
| 227 | + }) |
| 228 | +} |
0 commit comments