Skip to content

Commit 07cc7a5

Browse files
authored
Merge pull request #594 from stackql/claude/fix-stackql-tests-01FaQ7YN6kjqRvqDwtymJdDR
Add window function (OVER clause) and CTE (WITH clause)
2 parents 6825d5d + 8979e9f commit 07cc7a5

File tree

26 files changed

+1128
-2
lines changed

26 files changed

+1128
-2
lines changed

.github/workflows/lint.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ jobs:
3333
with:
3434
args: -color
3535

36+
- name: Update Go dependencies
37+
run: go mod tidy
3638

3739
- name: Run golangci-lint
3840
uses: golangci/golangci-lint-action@v8.0.0

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ require (
2020
github.com/stackql/any-sdk v0.3.1-beta01
2121
github.com/stackql/go-suffix-map v0.0.1-alpha01
2222
github.com/stackql/psql-wire v0.1.2-alpha01
23-
github.com/stackql/stackql-parser v0.0.15-alpha06
23+
github.com/stackql/stackql-parser v0.0.16-alpha01
2424
github.com/stretchr/testify v1.10.0
2525
golang.org/x/sync v0.15.0
2626
gonum.org/v1/gonum v0.15.1

internal/stackql/astanalysis/earlyanalysis/ast_expand.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ type indirectExpandAstVisitor struct {
4949
selectCount int
5050
mutateCount int
5151
createBuilder []primitivebuilder.Builder
52+
cteRegistry map[string]*sqlparser.Subquery // CTE name -> subquery definition
5253
}
5354

5455
func newIndirectExpandAstVisitor(
@@ -75,6 +76,7 @@ func newIndirectExpandAstVisitor(
7576
tcc: tcc,
7677
whereParams: whereParams,
7778
indirectionDepth: indirectionDepth,
79+
cteRegistry: make(map[string]*sqlparser.Subquery),
7880
}
7981
return rv, nil
8082
}
@@ -103,6 +105,41 @@ func (v *indirectExpandAstVisitor) processMaterializedView(
103105
return nil
104106
}
105107

108+
// processCTEReference handles CTE references by converting them to subquery indirects.
109+
// Returns true if the table name was a CTE reference and was processed, false otherwise.
110+
func (v *indirectExpandAstVisitor) processCTEReference(
111+
node *sqlparser.AliasedTableExpr,
112+
tableName string,
113+
) bool {
114+
cteSubquery, isCTE := v.cteRegistry[tableName]
115+
if !isCTE {
116+
return false
117+
}
118+
logging.GetLogger().Infof("processCTEReference: Converting CTE '%s' to subquery", tableName)
119+
logging.GetLogger().Debugf("processCTEReference: CTE subquery = %s", sqlparser.String(cteSubquery))
120+
// Modify the original node to replace the TableName with the CTE subquery
121+
// This is critical: downstream code (GetHIDs) checks node.Expr type to identify subqueries
122+
node.Expr = cteSubquery
123+
// Set the alias to the CTE name if no explicit alias was provided
124+
if node.As.IsEmpty() {
125+
node.As = sqlparser.NewTableIdent(tableName)
126+
}
127+
logging.GetLogger().Debugf("processCTEReference: Node alias set to '%s'", node.As.GetRawVal())
128+
sq := internaldto.NewSubqueryDTO(node, cteSubquery)
129+
indirect, err := astindirect.NewSubqueryIndirect(sq)
130+
if err != nil {
131+
logging.GetLogger().Errorf("processCTEReference: Failed to create subquery indirect: %v", err)
132+
return true //nolint:nilerr //TODO: investigate
133+
}
134+
err = v.processIndirect(node, indirect)
135+
if err != nil {
136+
logging.GetLogger().Errorf("processCTEReference: processIndirect failed: %v", err)
137+
} else {
138+
logging.GetLogger().Infof("processCTEReference: Successfully processed CTE '%s' as subquery", tableName)
139+
}
140+
return true
141+
}
142+
106143
func (v *indirectExpandAstVisitor) processIndirect(node sqlparser.SQLNode, indirect astindirect.Indirect) error {
107144
err := indirect.Parse()
108145
if err != nil {
@@ -214,6 +251,19 @@ func (v *indirectExpandAstVisitor) Visit(node sqlparser.SQLNode) error {
214251
addIf(node.StraightJoinHint, sqlparser.StraightJoinHint)
215252
addIf(node.SQLCalcFoundRows, sqlparser.SQLCalcFoundRowsStr)
216253

254+
// Extract CTEs from WITH clause and store in registry as Subqueries.
255+
// CTEs are converted to subqueries at the AST level for uniform handling.
256+
if node.With != nil {
257+
logging.GetLogger().Infof("Registering %d CTEs from WITH clause", len(node.With.CTEs))
258+
for _, cte := range node.With.CTEs {
259+
cteName := cte.Name.GetRawVal()
260+
// Wrap the CTE's SELECT statement in a Subquery struct
261+
cteSubquery := &sqlparser.Subquery{Select: cte.Select}
262+
v.cteRegistry[cteName] = cteSubquery
263+
logging.GetLogger().Debugf("Registered CTE '%s' with subquery: %s", cteName, sqlparser.String(cteSubquery))
264+
}
265+
}
266+
217267
if node.Comments != nil {
218268
node.Comments.Accept(v) //nolint:errcheck // future proof
219269
}
@@ -785,6 +835,11 @@ func (v *indirectExpandAstVisitor) Visit(node sqlparser.SQLNode) error {
785835
return nil //nolint:nilerr //TODO: investigate
786836
}
787837
return nil
838+
case sqlparser.TableName:
839+
// Check if this is a CTE reference - convert to subquery
840+
if v.processCTEReference(node, n.GetRawVal()) {
841+
return nil
842+
}
788843
}
789844
err := node.Expr.Accept(v)
790845
if err != nil {
@@ -822,6 +877,9 @@ func (v *indirectExpandAstVisitor) Visit(node sqlparser.SQLNode) error {
822877
if node.IsEmpty() {
823878
return nil
824879
}
880+
// Note: CTE references are handled in AliasedTableExpr case above,
881+
// where they are converted to subqueries. This case only handles
882+
// regular table names (provider.service.resource).
825883
containsBackendMaterial := v.handlerCtx.GetDBMSInternalRouter().ExprIsRoutable(node)
826884
if containsBackendMaterial {
827885
v.containsNativeBackendMaterial = true
Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
package earlyanalysis_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stackql/stackql-parser/go/vt/sqlparser"
7+
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func TestCTEParsing(t *testing.T) {
12+
t.Run("Simple CTE is parsed correctly", func(t *testing.T) {
13+
query := "WITH cte AS (SELECT id, name FROM users) SELECT * FROM cte"
14+
stmt, err := sqlparser.Parse(query)
15+
require.NoError(t, err)
16+
17+
sel, ok := stmt.(*sqlparser.Select)
18+
require.True(t, ok, "Statement should be a SELECT")
19+
20+
// Check that With clause exists
21+
require.NotNil(t, sel.With, "WITH clause should exist")
22+
require.Len(t, sel.With.CTEs, 1, "Should have 1 CTE")
23+
24+
// Check CTE name
25+
cte := sel.With.CTEs[0]
26+
assert.Equal(t, "cte", cte.Name.GetRawVal(), "CTE name should be 'cte'")
27+
28+
// Check that CTE has a select statement
29+
require.NotNil(t, cte.Select, "CTE should have a select statement")
30+
})
31+
32+
t.Run("Multiple CTEs are parsed correctly", func(t *testing.T) {
33+
query := "WITH a AS (SELECT 1 as x), b AS (SELECT 2 as y) SELECT * FROM a, b"
34+
stmt, err := sqlparser.Parse(query)
35+
require.NoError(t, err)
36+
37+
sel, ok := stmt.(*sqlparser.Select)
38+
require.True(t, ok, "Statement should be a SELECT")
39+
40+
// Check that With clause exists
41+
require.NotNil(t, sel.With, "WITH clause should exist")
42+
require.Len(t, sel.With.CTEs, 2, "Should have 2 CTEs")
43+
44+
// Check CTE names
45+
assert.Equal(t, "a", sel.With.CTEs[0].Name.GetRawVal(), "First CTE name should be 'a'")
46+
assert.Equal(t, "b", sel.With.CTEs[1].Name.GetRawVal(), "Second CTE name should be 'b'")
47+
})
48+
49+
t.Run("Recursive CTE is parsed correctly", func(t *testing.T) {
50+
query := "WITH RECURSIVE cte AS (SELECT 1 as n UNION ALL SELECT n + 1 FROM cte WHERE n < 10) SELECT * FROM cte"
51+
stmt, err := sqlparser.Parse(query)
52+
require.NoError(t, err)
53+
54+
sel, ok := stmt.(*sqlparser.Select)
55+
require.True(t, ok, "Statement should be a SELECT")
56+
57+
// Check that With clause exists with Recursive flag
58+
require.NotNil(t, sel.With, "WITH clause should exist")
59+
assert.True(t, sel.With.Recursive, "WITH clause should be RECURSIVE")
60+
require.Len(t, sel.With.CTEs, 1, "Should have 1 CTE")
61+
62+
// Check CTE name
63+
assert.Equal(t, "cte", sel.With.CTEs[0].Name.GetRawVal(), "CTE name should be 'cte'")
64+
})
65+
66+
t.Run("CTE with column aliases", func(t *testing.T) {
67+
query := "WITH cte(col1, col2) AS (SELECT id, name FROM users) SELECT * FROM cte"
68+
stmt, err := sqlparser.Parse(query)
69+
require.NoError(t, err)
70+
71+
sel, ok := stmt.(*sqlparser.Select)
72+
require.True(t, ok, "Statement should be a SELECT")
73+
74+
require.NotNil(t, sel.With, "WITH clause should exist")
75+
require.Len(t, sel.With.CTEs, 1, "Should have 1 CTE")
76+
77+
cte := sel.With.CTEs[0]
78+
assert.Equal(t, "cte", cte.Name.GetRawVal(), "CTE name should be 'cte'")
79+
80+
// Check column aliases if present
81+
require.Len(t, cte.Columns, 2, "CTE should have 2 column aliases")
82+
assert.Equal(t, "col1", cte.Columns[0].GetRawVal(), "First column alias should be 'col1'")
83+
assert.Equal(t, "col2", cte.Columns[1].GetRawVal(), "Second column alias should be 'col2'")
84+
})
85+
86+
t.Run("Nested CTEs - CTE referencing another CTE", func(t *testing.T) {
87+
query := "WITH a AS (SELECT 1 as x), b AS (SELECT x * 2 as y FROM a) SELECT * FROM b"
88+
stmt, err := sqlparser.Parse(query)
89+
require.NoError(t, err)
90+
91+
sel, ok := stmt.(*sqlparser.Select)
92+
require.True(t, ok, "Statement should be a SELECT")
93+
94+
require.NotNil(t, sel.With, "WITH clause should exist")
95+
require.Len(t, sel.With.CTEs, 2, "Should have 2 CTEs")
96+
})
97+
}
98+
99+
func TestCTERegistry(t *testing.T) {
100+
t.Run("CTE registry stores CTEs correctly", func(t *testing.T) {
101+
registry := make(map[string]*sqlparser.CommonTableExpr)
102+
103+
query := "WITH cte1 AS (SELECT 1), cte2 AS (SELECT 2) SELECT * FROM cte1, cte2"
104+
stmt, err := sqlparser.Parse(query)
105+
require.NoError(t, err)
106+
107+
sel := stmt.(*sqlparser.Select)
108+
require.NotNil(t, sel.With)
109+
110+
// Simulate what the visitor does - register CTEs
111+
for _, cte := range sel.With.CTEs {
112+
cteName := cte.Name.GetRawVal()
113+
registry[cteName] = cte
114+
}
115+
116+
// Verify registry contents
117+
assert.Len(t, registry, 2, "Registry should have 2 CTEs")
118+
assert.Contains(t, registry, "cte1", "Registry should contain 'cte1'")
119+
assert.Contains(t, registry, "cte2", "Registry should contain 'cte2'")
120+
})
121+
122+
t.Run("CTE lookup works correctly", func(t *testing.T) {
123+
registry := make(map[string]*sqlparser.CommonTableExpr)
124+
125+
query := "WITH my_cte AS (SELECT id, name FROM users) SELECT * FROM my_cte"
126+
stmt, err := sqlparser.Parse(query)
127+
require.NoError(t, err)
128+
129+
sel := stmt.(*sqlparser.Select)
130+
require.NotNil(t, sel.With)
131+
132+
// Register the CTE
133+
for _, cte := range sel.With.CTEs {
134+
cteName := cte.Name.GetRawVal()
135+
registry[cteName] = cte
136+
}
137+
138+
// Verify we can look up the CTE
139+
_, isCTE := registry["my_cte"]
140+
assert.True(t, isCTE, "'my_cte' should be found in registry")
141+
142+
// Verify non-CTE names are not found
143+
_, isNotCTE := registry["users"]
144+
assert.False(t, isNotCTE, "'users' should not be found in registry")
145+
})
146+
}
147+
148+
func TestWindowFunctionParsing(t *testing.T) {
149+
t.Run("Window function with OVER clause is parsed correctly", func(t *testing.T) {
150+
query := "SELECT ROW_NUMBER() OVER (ORDER BY id) as row_num FROM t"
151+
stmt, err := sqlparser.Parse(query)
152+
require.NoError(t, err)
153+
154+
sel, ok := stmt.(*sqlparser.Select)
155+
require.True(t, ok, "Statement should be a SELECT")
156+
157+
require.Len(t, sel.SelectExprs, 1, "Should have 1 select expression")
158+
159+
aliased, ok := sel.SelectExprs[0].(*sqlparser.AliasedExpr)
160+
require.True(t, ok, "Select expression should be aliased")
161+
162+
funcExpr, ok := aliased.Expr.(*sqlparser.FuncExpr)
163+
require.True(t, ok, "Expression should be a FuncExpr")
164+
165+
assert.Equal(t, "row_number", funcExpr.Name.Lowered(), "Function name should be 'row_number'")
166+
assert.NotNil(t, funcExpr.Over, "FuncExpr should have OVER clause")
167+
})
168+
169+
t.Run("Window function with PARTITION BY is parsed correctly", func(t *testing.T) {
170+
query := "SELECT SUM(amount) OVER (PARTITION BY category ORDER BY date) as running_sum FROM t"
171+
stmt, err := sqlparser.Parse(query)
172+
require.NoError(t, err)
173+
174+
sel := stmt.(*sqlparser.Select)
175+
aliased := sel.SelectExprs[0].(*sqlparser.AliasedExpr)
176+
funcExpr := aliased.Expr.(*sqlparser.FuncExpr)
177+
178+
assert.Equal(t, "sum", funcExpr.Name.Lowered())
179+
require.NotNil(t, funcExpr.Over, "FuncExpr should have OVER clause")
180+
181+
// Check partition by exists
182+
require.NotNil(t, funcExpr.Over.PartitionBy, "OVER clause should have PARTITION BY")
183+
})
184+
185+
t.Run("Window function with frame specification", func(t *testing.T) {
186+
query := "SELECT SUM(value) OVER (ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as cumsum FROM t"
187+
stmt, err := sqlparser.Parse(query)
188+
require.NoError(t, err)
189+
190+
sel := stmt.(*sqlparser.Select)
191+
aliased := sel.SelectExprs[0].(*sqlparser.AliasedExpr)
192+
funcExpr := aliased.Expr.(*sqlparser.FuncExpr)
193+
194+
assert.NotNil(t, funcExpr.Over, "FuncExpr should have OVER clause")
195+
})
196+
197+
t.Run("Multiple window functions in query", func(t *testing.T) {
198+
query := "SELECT ROW_NUMBER() OVER (ORDER BY id) as rn, RANK() OVER (ORDER BY score DESC) as rank FROM t"
199+
stmt, err := sqlparser.Parse(query)
200+
require.NoError(t, err)
201+
202+
sel := stmt.(*sqlparser.Select)
203+
require.Len(t, sel.SelectExprs, 2, "Should have 2 select expressions")
204+
205+
// Check first window function
206+
aliased1 := sel.SelectExprs[0].(*sqlparser.AliasedExpr)
207+
funcExpr1 := aliased1.Expr.(*sqlparser.FuncExpr)
208+
assert.NotNil(t, funcExpr1.Over, "First FuncExpr should have OVER clause")
209+
210+
// Check second window function
211+
aliased2 := sel.SelectExprs[1].(*sqlparser.AliasedExpr)
212+
funcExpr2 := aliased2.Expr.(*sqlparser.FuncExpr)
213+
assert.NotNil(t, funcExpr2.Over, "Second FuncExpr should have OVER clause")
214+
})
215+
216+
t.Run("Regular function without OVER clause", func(t *testing.T) {
217+
query := "SELECT UPPER(name) as upper_name FROM t"
218+
stmt, err := sqlparser.Parse(query)
219+
require.NoError(t, err)
220+
221+
sel := stmt.(*sqlparser.Select)
222+
aliased := sel.SelectExprs[0].(*sqlparser.AliasedExpr)
223+
funcExpr := aliased.Expr.(*sqlparser.FuncExpr)
224+
225+
assert.Equal(t, "upper", funcExpr.Name.Lowered())
226+
assert.Nil(t, funcExpr.Over, "UPPER() should not have OVER clause")
227+
})
228+
}

0 commit comments

Comments
 (0)