Skip to content

Commit 63525c7

Browse files
committed
Convert CTEs to subqueries at AST level for proper execution
Follow the stackql-devel pattern: CTEs are converted to Subqueries at AST level in ast_expand.go, making downstream code treat them uniformly with regular subqueries. Key changes: - Change cteRegistry to store *sqlparser.Subquery instead of CommonTableExpr - Add processCTEReference() which replaces TableName with Subquery in AST - Handle CTE detection in AliasedTableExpr case (not TableName case) - Remove CTEType-specific handling from from_rewrite.go, parameter_router.go, obtain_context.go, and dependencyplanner.go since SubqueryType path handles them - Add logging for debugging CTE processing
1 parent feac75c commit 63525c7

File tree

6 files changed

+71
-61
lines changed

6 files changed

+71
-61
lines changed

internal/stackql/astanalysis/earlyanalysis/ast_expand.go

Lines changed: 50 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ type indirectExpandAstVisitor struct {
4949
selectCount int
5050
mutateCount int
5151
createBuilder []primitivebuilder.Builder
52-
cteRegistry map[string]*sqlparser.CommonTableExpr
52+
cteRegistry map[string]*sqlparser.Subquery // CTE name -> subquery definition
5353
}
5454

5555
func newIndirectExpandAstVisitor(
@@ -76,7 +76,7 @@ func newIndirectExpandAstVisitor(
7676
tcc: tcc,
7777
whereParams: whereParams,
7878
indirectionDepth: indirectionDepth,
79-
cteRegistry: make(map[string]*sqlparser.CommonTableExpr),
79+
cteRegistry: make(map[string]*sqlparser.Subquery),
8080
}
8181
return rv, nil
8282
}
@@ -105,6 +105,41 @@ func (v *indirectExpandAstVisitor) processMaterializedView(
105105
return nil
106106
}
107107

108+
// processCTEReference handles CTE references by converting them to subquery indirects.
109+
// Returns true if the table name was a CTE reference and was processed, false otherwise.
110+
func (v *indirectExpandAstVisitor) processCTEReference(
111+
node *sqlparser.AliasedTableExpr,
112+
tableName string,
113+
) bool {
114+
cteSubquery, isCTE := v.cteRegistry[tableName]
115+
if !isCTE {
116+
return false
117+
}
118+
logging.GetLogger().Infof("processCTEReference: Converting CTE '%s' to subquery", tableName)
119+
logging.GetLogger().Debugf("processCTEReference: CTE subquery = %s", sqlparser.String(cteSubquery))
120+
// Modify the original node to replace the TableName with the CTE subquery
121+
// This is critical: downstream code (GetHIDs) checks node.Expr type to identify subqueries
122+
node.Expr = cteSubquery
123+
// Set the alias to the CTE name if no explicit alias was provided
124+
if node.As.IsEmpty() {
125+
node.As = sqlparser.NewTableIdent(tableName)
126+
}
127+
logging.GetLogger().Debugf("processCTEReference: Node alias set to '%s'", node.As.GetRawVal())
128+
sq := internaldto.NewSubqueryDTO(node, cteSubquery)
129+
indirect, err := astindirect.NewSubqueryIndirect(sq)
130+
if err != nil {
131+
logging.GetLogger().Errorf("processCTEReference: Failed to create subquery indirect: %v", err)
132+
return true //nolint:nilerr //TODO: investigate
133+
}
134+
err = v.processIndirect(node, indirect)
135+
if err != nil {
136+
logging.GetLogger().Errorf("processCTEReference: processIndirect failed: %v", err)
137+
} else {
138+
logging.GetLogger().Infof("processCTEReference: Successfully processed CTE '%s' as subquery", tableName)
139+
}
140+
return true
141+
}
142+
108143
func (v *indirectExpandAstVisitor) processIndirect(node sqlparser.SQLNode, indirect astindirect.Indirect) error {
109144
err := indirect.Parse()
110145
if err != nil {
@@ -216,14 +251,14 @@ func (v *indirectExpandAstVisitor) Visit(node sqlparser.SQLNode) error {
216251
addIf(node.StraightJoinHint, sqlparser.StraightJoinHint)
217252
addIf(node.SQLCalcFoundRows, sqlparser.SQLCalcFoundRowsStr)
218253

219-
// Register CTEs (Common Table Expressions) if present.
220-
// We only register CTE names here - the actual SELECT processing
221-
// happens in processIndirect when the CTE is referenced.
222-
// This mirrors how subqueries work (processed once in processIndirect).
254+
// Extract CTEs from WITH clause and store in registry as Subqueries.
255+
// CTEs are converted to subqueries at the AST level for uniform handling.
223256
if node.With != nil {
257+
logging.GetLogger().Infof("Registering %d CTEs from WITH clause", len(node.With.CTEs))
224258
for _, cte := range node.With.CTEs {
225259
cteName := cte.Name.GetRawVal()
226-
v.cteRegistry[cteName] = cte
260+
v.cteRegistry[cteName] = cte.Subquery
261+
logging.GetLogger().Debugf("Registered CTE '%s' with subquery: %s", cteName, sqlparser.String(cte.Subquery))
227262
}
228263
}
229264

@@ -798,6 +833,11 @@ func (v *indirectExpandAstVisitor) Visit(node sqlparser.SQLNode) error {
798833
return nil //nolint:nilerr //TODO: investigate
799834
}
800835
return nil
836+
case sqlparser.TableName:
837+
// Check if this is a CTE reference - convert to subquery
838+
if v.processCTEReference(node, n.GetRawVal()) {
839+
return nil
840+
}
801841
}
802842
err := node.Expr.Accept(v)
803843
if err != nil {
@@ -835,22 +875,9 @@ func (v *indirectExpandAstVisitor) Visit(node sqlparser.SQLNode) error {
835875
if node.IsEmpty() {
836876
return nil
837877
}
838-
// Check if this is a CTE reference.
839-
cteName := node.Name.GetRawVal()
840-
if cte, isCTE := v.cteRegistry[cteName]; isCTE {
841-
// This is a CTE reference - create and process an indirect.
842-
// We need to call processIndirect to get the proper column context
843-
// from the CTE's SELECT statement.
844-
indirect, err := astindirect.NewCTEIndirect(cte)
845-
if err != nil {
846-
return err
847-
}
848-
err = v.processIndirect(node, indirect)
849-
if err != nil {
850-
return err
851-
}
852-
return nil
853-
}
878+
// Note: CTE references are handled in AliasedTableExpr case above,
879+
// where they are converted to subqueries. This case only handles
880+
// regular table names (provider.service.resource).
854881
containsBackendMaterial := v.handlerCtx.GetDBMSInternalRouter().ExprIsRoutable(node)
855882
if containsBackendMaterial {
856883
v.containsNativeBackendMaterial = true

internal/stackql/astvisit/from_rewrite.go

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -668,15 +668,11 @@ func (v *standardFromRewriteAstVisitor) Visit(node sqlparser.SQLNode) error {
668668
v.rewrittenQuery = templateString
669669
v.indirectContexts = append(v.indirectContexts, indirect.GetSelectContext())
670670
case astindirect.SubqueryType:
671+
// Note: CTEs are converted to SubqueryType at AST level,
672+
// so this path handles both regular subqueries and CTEs.
671673
templateString := ` ( %s ) `
672674
v.rewrittenQuery = templateString
673675
v.indirectContexts = append(v.indirectContexts, indirect.GetSelectContext())
674-
case astindirect.CTEType:
675-
// CTEs are handled like views - the inner SELECT is executed
676-
// and results are wrapped with the CTE name as alias.
677-
templateString := fmt.Sprintf(` ( %%s ) AS "%s" `, name)
678-
v.rewrittenQuery = templateString
679-
v.indirectContexts = append(v.indirectContexts, indirect.GetSelectContext())
680676
case astindirect.MaterializedViewType, astindirect.PhysicalTableType:
681677
refString := fmt.Sprintf(` %s `, name)
682678
isQuoted, _ := regexp.MatchString(`^".*"$`, name)

internal/stackql/dependencyplanner/dependencyplanner.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -164,10 +164,8 @@ func (dp *standardDependencyPlanner) Plan() error {
164164
annotation := unit.GetAnnotation()
165165
_, isView := annotation.GetView()
166166
_, isSubquery := annotation.GetSubquery()
167-
// Check if this is a CTE reference.
168-
indirect, hasIndirect := annotation.GetTableMeta().GetIndirect()
169-
isCTE := hasIndirect && indirect != nil && indirect.GetType() == astindirect.CTEType
170-
if isView || isSubquery || isCTE {
167+
// Note: CTEs are converted to subqueries at AST level, so isSubquery handles them.
168+
if isView || isSubquery {
171169
dp.annMap[tableExpr] = annotation
172170
continue
173171
}

internal/stackql/primitivegenerator/select.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,8 +224,8 @@ func (pb *standardPrimitiveGenerator) analyzeSelect(pbi planbuilderinput.PlanBui
224224
selCtx := dp.GetSelectCtx()
225225
pChild.GetPrimitiveComposer().SetBuilder(bld)
226226
pb.PrimitiveComposer.SetSelectPreparedStatementCtx(selCtx)
227-
// For indirects (like CTEs), also set the builder on the indirect's primitive
228-
// composer so it can be found when parent's GetBuilder() iterates over indirects.
227+
// For indirects (subqueries, views, etc.), also set the builder on the indirect's
228+
// primitive composer so it can be found when parent's GetBuilder() iterates.
229229
if pb.PrimitiveComposer.IsIndirect() {
230230
pb.PrimitiveComposer.SetBuilder(bld)
231231
}

internal/stackql/router/obtain_context/obtain_context.go

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package obtain_context //nolint:revive,cyclop,stylecheck // TODO: allow
33
import (
44
"fmt"
55

6-
"github.com/stackql/stackql/internal/stackql/astindirect"
76
"github.com/stackql/stackql/internal/stackql/internal_data_transfer/internaldto"
87
"github.com/stackql/stackql/internal/stackql/sql_system"
98
"github.com/stackql/stackql/internal/stackql/tablemetadata"
@@ -21,10 +20,8 @@ func ObtainAnnotationCtx(
2120
_, isSQLDataSource := tbl.GetSQLDataSource()
2221
_, isSubquery := tbl.GetSubquery()
2322
isPGInternalObject := tbl.GetHeirarchyObjects().IsPGInternalObject()
24-
// Check if this is a CTE reference.
25-
indirect, hasIndirect := tbl.GetIndirect()
26-
isCTE := hasIndirect && indirect != nil && indirect.GetType() == astindirect.CTEType
27-
if isView || isSQLDataSource || isSubquery || isPGInternalObject || isCTE {
23+
// Note: CTEs are converted to subqueries at AST level, so isSubquery handles them.
24+
if isView || isSQLDataSource || isSubquery || isPGInternalObject {
2825
// TODO: upgrade this flow; nil == YUCK!!!
2926
return taxonomy.NewStaticStandardAnnotationCtx(
3027
nil, tbl.GetHeirarchyObjects().GetHeirarchyIDs(),

internal/stackql/router/parameter_router.go

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -592,30 +592,22 @@ func (pr *standardParameterRouter) route(
592592
priorParameters := runParamters.Clone()
593593
// notOnStringParams := notOnParams.GetStringified()
594594
// TODO: add parent params into the mix here.
595-
// Check for CTE indirect before attempting hierarchy resolution.
596-
// CTEs don't need provider/service/resource resolution.
597-
cteIndirect, isCTE := pr.annotatedAST.GetIndirect(tb)
595+
// Note: CTEs are converted to subqueries at AST level, so they follow
596+
// the normal subquery path (handled via GetSubquery() in ObtainAnnotationCtx).
598597
var hr tablemetadata.HeirarchyObjects
599598
var err error
600-
if isCTE && cteIndirect != nil && cteIndirect.GetType() == astindirect.CTEType {
601-
// CTE reference - create empty hierarchy identifiers.
602-
tableName := taxonomy.GetTableNameFromStatement(tb, pr.astFormatter)
603-
hIDs := internaldto.NewHeirarchyIdentifiers("", "", tableName, "")
604-
hr = tablemetadata.NewHeirarchyObjects(hIDs, isAwait)
599+
hr, err = taxonomy.GetHeirarchyFromStatement(handlerCtx, tb, notOnParams, false, isAwait)
600+
if err != nil {
601+
hr, err = taxonomy.GetHeirarchyFromStatement(handlerCtx, tb, runParamters, false, isAwait)
605602
} else {
606-
hr, err = taxonomy.GetHeirarchyFromStatement(handlerCtx, tb, notOnParams, false, isAwait)
607-
if err != nil {
608-
hr, err = taxonomy.GetHeirarchyFromStatement(handlerCtx, tb, runParamters, false, isAwait)
609-
} else {
610-
// If the where parameters are sufficient, then need to switch
611-
// the Table - Paramater coupling object
612-
runParamters = notOnParams
613-
priorParameters = priorNotOnParameters
614-
}
615-
// logging.GetLogger().Infof("hr = '%+v', remainingParams = '%+v', err = '%+v'", hr, remainingParams, err)
616-
if err != nil {
617-
return nil, err
618-
}
603+
// If the where parameters are sufficient, then need to switch
604+
// the Table - Paramater coupling object
605+
runParamters = notOnParams
606+
priorParameters = priorNotOnParameters
607+
}
608+
// logging.GetLogger().Infof("hr = '%+v', remainingParams = '%+v', err = '%+v'", hr, remainingParams, err)
609+
if err != nil {
610+
return nil, err
619611
}
620612
// reconstitutedConsumedParams, err := tpc.ReconstituteConsumedParams(remainingParams)
621613
// if err != nil {

0 commit comments

Comments
 (0)