Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 78 additions & 37 deletions internal/stackql/astanalysis/earlyanalysis/ast_expand.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ type indirectExpandAstVisitor struct {
selectCount int
mutateCount int
createBuilder []primitivebuilder.Builder
cteRegistry map[string]*sqlparser.Subquery // CTE name -> subquery definition
}

func newIndirectExpandAstVisitor(
Expand All @@ -75,6 +76,7 @@ func newIndirectExpandAstVisitor(
tcc: tcc,
whereParams: whereParams,
indirectionDepth: indirectionDepth,
cteRegistry: make(map[string]*sqlparser.Subquery),
}
return rv, nil
}
Expand Down Expand Up @@ -178,6 +180,72 @@ func (v *indirectExpandAstVisitor) IsReadOnly() bool {
return v.selectCount > 0 && v.mutateCount == 0
}

// processCTEReference handles CTE references by converting them to subquery indirects.
// Returns true if the table name was a CTE reference and was processed, false otherwise.
func (v *indirectExpandAstVisitor) processCTEReference(
node *sqlparser.AliasedTableExpr,
tableName string,
) bool {
cteSubquery, isCTE := v.cteRegistry[tableName]
if !isCTE {
return false
}
// Modify the original node to replace the TableName with the CTE subquery
// This is critical: downstream code (GetHIDs) checks node.Expr type to identify subqueries
node.Expr = cteSubquery
// Set the alias to the CTE name if no explicit alias was provided
if node.As.IsEmpty() {
node.As = sqlparser.NewTableIdent(tableName)
}
sq := internaldto.NewSubqueryDTO(node, cteSubquery)
indirect, err := astindirect.NewSubqueryIndirect(sq)
if err != nil {
return true //nolint:nilerr //TODO: investigate
}
_ = v.processIndirect(node, indirect) //nolint:errcheck // errors handled via indirect pattern
return true
}

// visitAliasedTableExpr handles visiting an AliasedTableExpr node, including
// subqueries, CTE references, and regular table expressions.
func (v *indirectExpandAstVisitor) visitAliasedTableExpr(node *sqlparser.AliasedTableExpr) error {
if node.Expr != nil {
switch n := node.Expr.(type) { //nolint:gocritic // switch preferred for type assertions
case *sqlparser.Subquery:
sq := internaldto.NewSubqueryDTO(node, n)
indirect, err := astindirect.NewSubqueryIndirect(sq)
if err != nil {
return nil //nolint:nilerr //TODO: investigate
}
_ = v.processIndirect(node, indirect) //nolint:errcheck // errors handled via indirect pattern
return nil
case sqlparser.TableName:
if v.processCTEReference(node, n.GetRawVal()) {
return nil
}
}
if err := node.Expr.Accept(v); err != nil {
return err
}
}
if node.Partitions != nil {
if err := node.Partitions.Accept(v); err != nil {
return err
}
}
if !node.As.IsEmpty() {
if err := node.As.Accept(v); err != nil {
return err
}
}
if node.Hints != nil {
if err := node.Hints.Accept(v); err != nil {
return err
}
}
return nil
}

func (v *indirectExpandAstVisitor) ContainsAnalyticsCacheMaterial() bool {
return v.containsAnalyticsCacheMaterial
}
Expand Down Expand Up @@ -214,6 +282,14 @@ func (v *indirectExpandAstVisitor) Visit(node sqlparser.SQLNode) error {
addIf(node.StraightJoinHint, sqlparser.StraightJoinHint)
addIf(node.SQLCalcFoundRows, sqlparser.SQLCalcFoundRowsStr)

// Extract CTEs from WITH clause and store in registry
if node.With != nil {
for _, cte := range node.With.CTEs {
cteName := cte.Name.GetRawVal()
v.cteRegistry[cteName] = cte.Subquery
}
}

if node.Comments != nil {
node.Comments.Accept(v) //nolint:errcheck // future proof
}
Expand Down Expand Up @@ -771,43 +847,8 @@ func (v *indirectExpandAstVisitor) Visit(node sqlparser.SQLNode) error {
}

case *sqlparser.AliasedTableExpr:
if node.Expr != nil {
//nolint:gocritic // deferring cosmetics on visitors
switch n := node.Expr.(type) {
case *sqlparser.Subquery:
sq := internaldto.NewSubqueryDTO(node, n)
indirect, err := astindirect.NewSubqueryIndirect(sq)
if err != nil {
return nil //nolint:nilerr //TODO: investigate
}
err = v.processIndirect(node, indirect)
if err != nil {
return nil //nolint:nilerr //TODO: investigate
}
return nil
}
err := node.Expr.Accept(v)
if err != nil {
return err
}
}
if node.Partitions != nil {
err := node.Partitions.Accept(v)
if err != nil {
return err
}
}
if !node.As.IsEmpty() {
err := node.As.Accept(v)
if err != nil {
return err
}
}
if node.Hints != nil {
err := node.Hints.Accept(v)
if err != nil {
return err
}
if err := v.visitAliasedTableExpr(node); err != nil {
return err
}

case sqlparser.TableNames:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
|---------------------|-------------------------|
| name | maximumCardsPerInstance |
|---------------------|-------------------------|
| nvidia-tesla-p4 | 4 |
|---------------------|-------------------------|
| nvidia-tesla-p4-vws | 4 |
|---------------------|-------------------------|
| nvidia-tesla-t4 | 4 |
|---------------------|-------------------------|
| nvidia-tesla-t4-vws | 4 |
|---------------------|-------------------------|
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
|---------------------|------|
| name | rank |
|---------------------|------|
| nvidia-tesla-p4 | 1 |
|---------------------|------|
| nvidia-tesla-p4-vws | 1 |
|---------------------|------|
| nvidia-tesla-t4 | 1 |
|---------------------|------|
| nvidia-tesla-t4-vws | 1 |
|---------------------|------|
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
|---------------------|------------|
| name | row_number |
|---------------------|------------|
| nvidia-tesla-p4 | 1 |
|---------------------|------------|
| nvidia-tesla-p4-vws | 2 |
|---------------------|------------|
| nvidia-tesla-t4 | 3 |
|---------------------|------------|
| nvidia-tesla-t4-vws | 4 |
|---------------------|------------|
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
|---------------------|-------------|
| name | running_sum |
|---------------------|-------------|
| nvidia-tesla-p4 | 4 |
|---------------------|-------------|
| nvidia-tesla-p4-vws | 8 |
|---------------------|-------------|
| nvidia-tesla-t4 | 12 |
|---------------------|-------------|
| nvidia-tesla-t4-vws | 16 |
|---------------------|-------------|
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
|---------------------|-----------|
| name | total_sum |
|---------------------|-----------|
| nvidia-tesla-p4 | 16 |
|---------------------|-----------|
| nvidia-tesla-p4-vws | 16 |
|---------------------|-----------|
| nvidia-tesla-t4 | 16 |
|---------------------|-----------|
| nvidia-tesla-t4-vws | 16 |
|---------------------|-----------|
28 changes: 28 additions & 0 deletions test/python/stackql_test_tooling/stackql_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,15 @@ def generate_password() -> str:
SELECT_MACHINE_TYPES_DESC = "select name from google.compute.machineTypes where project = 'testing-project' and zone = 'australia-southeast1-a' order by name desc;"
SELECT_GOOGLE_COMPUTE_INSTANCE_IAM_POLICY = "SELECT etag FROM google.compute.instances_iam_policies WHERE project = 'testing-project' AND zone = 'australia-southeast1-a' AND resource = '000000001';"

# Window function tests using accelerator types
SELECT_ACCELERATOR_TYPES_ROW_NUMBER = "SELECT name, ROW_NUMBER() OVER (ORDER BY name ASC) as row_number FROM google.compute.acceleratorTypes WHERE project = 'testing-project' AND zone = 'australia-southeast1-a' ORDER BY name ASC;"
SELECT_ACCELERATOR_TYPES_SUM_OVER = "SELECT name, SUM(maximumCardsPerInstance) OVER () as total_sum FROM google.compute.acceleratorTypes WHERE project = 'testing-project' AND zone = 'australia-southeast1-a' ORDER BY name ASC;"
SELECT_ACCELERATOR_TYPES_RUNNING_SUM = "SELECT name, SUM(maximumCardsPerInstance) OVER (ORDER BY name ASC) as running_sum FROM google.compute.acceleratorTypes WHERE project = 'testing-project' AND zone = 'australia-southeast1-a' ORDER BY name ASC;"
SELECT_ACCELERATOR_TYPES_RANK = "SELECT name, RANK() OVER (ORDER BY maximumCardsPerInstance) as rank FROM google.compute.acceleratorTypes WHERE project = 'testing-project' AND zone = 'australia-southeast1-a' ORDER BY name ASC;"

# CTE (Common Table Expression) tests using accelerator types
SELECT_ACCELERATOR_TYPES_SIMPLE_CTE = "WITH accel_types AS (SELECT name, maximumCardsPerInstance FROM google.compute.acceleratorTypes WHERE project = 'testing-project' AND zone = 'australia-southeast1-a') SELECT name, maximumCardsPerInstance FROM accel_types ORDER BY name ASC;"

SELECT_AWS_CLOUD_CONTROL_EVENTS_MINIMAL = "SELECT DISTINCT EventTime, Identifier from aws.cloud_control.resource_requests where data__ResourceRequestStatusFilter='{}' and region = 'ap-southeast-1' order by Identifier, EventTime;"

SELECT_AZURE_COMPUTE_PUBLIC_KEYS = "select id, location from azure.compute.ssh_public_keys where subscriptionId = '10001000-1000-1000-1000-100010001000' ORDER BY id ASC;"
Expand Down Expand Up @@ -765,6 +774,15 @@ def get_native_table_count_by_name(table_name :str, sql_backend_str :str) -> str

SELECT_ACCELERATOR_TYPES_DESC_EXPECTED = get_output_from_local_file(os.path.join('test', 'assets', 'expected', 'simple-select', 'compute-accelerator-type', 'select-zone-list-desc.txt'))

# Window function test expected results
SELECT_ACCELERATOR_TYPES_ROW_NUMBER_EXPECTED = get_output_from_local_file(os.path.join('test', 'assets', 'expected', 'simple-select', 'compute-accelerator-type', 'window-functions', 'row-number-by-name.txt'))
SELECT_ACCELERATOR_TYPES_SUM_OVER_EXPECTED = get_output_from_local_file(os.path.join('test', 'assets', 'expected', 'simple-select', 'compute-accelerator-type', 'window-functions', 'sum-over-all.txt'))
SELECT_ACCELERATOR_TYPES_RUNNING_SUM_EXPECTED = get_output_from_local_file(os.path.join('test', 'assets', 'expected', 'simple-select', 'compute-accelerator-type', 'window-functions', 'running-sum-by-name.txt'))
SELECT_ACCELERATOR_TYPES_RANK_EXPECTED = get_output_from_local_file(os.path.join('test', 'assets', 'expected', 'simple-select', 'compute-accelerator-type', 'window-functions', 'rank-by-cards.txt'))

# CTE test expected results
SELECT_ACCELERATOR_TYPES_SIMPLE_CTE_EXPECTED = get_output_from_local_file(os.path.join('test', 'assets', 'expected', 'simple-select', 'compute-accelerator-type', 'cte', 'simple-cte.txt'))

SELECT_MACHINE_TYPES_DESC_EXPECTED = get_output_from_local_file(os.path.join('test', 'assets', 'expected', 'google', 'compute', 'instance-type-list-names-paginated-desc.txt'))

SELECT_OKTA_APPS_ASC_EXPECTED = get_output_from_local_file(os.path.join('test', 'assets', 'expected', 'simple-select', 'okta', 'apps', 'select-apps-asc.txt'))
Expand Down Expand Up @@ -987,6 +1005,16 @@ def get_registry_mock_url(execution_env :str) -> str:
'SELECT_ACCELERATOR_TYPES_DESC': SELECT_ACCELERATOR_TYPES_DESC,
'SELECT_ACCELERATOR_TYPES_DESC_EXPECTED': SELECT_ACCELERATOR_TYPES_DESC_EXPECTED,
'SELECT_ACCELERATOR_TYPES_DESC_SEQUENCE': [ SELECT_ACCELERATOR_TYPES_DESC, SELECT_ACCELERATOR_TYPES_DESC_FROM_INTEL_VIEWS, SELECT_ACCELERATOR_TYPES_DESC_FROM_INTEL_VIEWS_SUBQUERY ],
'SELECT_ACCELERATOR_TYPES_RANK': SELECT_ACCELERATOR_TYPES_RANK,
'SELECT_ACCELERATOR_TYPES_RANK_EXPECTED': SELECT_ACCELERATOR_TYPES_RANK_EXPECTED,
'SELECT_ACCELERATOR_TYPES_ROW_NUMBER': SELECT_ACCELERATOR_TYPES_ROW_NUMBER,
'SELECT_ACCELERATOR_TYPES_ROW_NUMBER_EXPECTED': SELECT_ACCELERATOR_TYPES_ROW_NUMBER_EXPECTED,
'SELECT_ACCELERATOR_TYPES_RUNNING_SUM': SELECT_ACCELERATOR_TYPES_RUNNING_SUM,
'SELECT_ACCELERATOR_TYPES_RUNNING_SUM_EXPECTED': SELECT_ACCELERATOR_TYPES_RUNNING_SUM_EXPECTED,
'SELECT_ACCELERATOR_TYPES_SUM_OVER': SELECT_ACCELERATOR_TYPES_SUM_OVER,
'SELECT_ACCELERATOR_TYPES_SUM_OVER_EXPECTED': SELECT_ACCELERATOR_TYPES_SUM_OVER_EXPECTED,
'SELECT_ACCELERATOR_TYPES_SIMPLE_CTE': SELECT_ACCELERATOR_TYPES_SIMPLE_CTE,
'SELECT_ACCELERATOR_TYPES_SIMPLE_CTE_EXPECTED': SELECT_ACCELERATOR_TYPES_SIMPLE_CTE_EXPECTED,
'SELECT_ANALYTICS_CACHE_GITHUB_REPOSITORIES_COLLABORATORS_EXPECTED': SELECT_ANALYTICS_CACHE_GITHUB_REPOSITORIES_COLLABORATORS_EXPECTED,
'SELECT_ANALYTICS_CACHE_GITHUB_REPOSITORIES_COLLABORATORS_SIMPLE': SELECT_ANALYTICS_CACHE_GITHUB_REPOSITORIES_COLLABORATORS_SIMPLE,
'SELECT_ANALYTICS_CACHE_GITHUB_REPOSITORIES_COLLABORATORS_TRANSPARENT': SELECT_ANALYTICS_CACHE_GITHUB_REPOSITORIES_COLLABORATORS_TRANSPARENT,
Expand Down
60 changes: 60 additions & 0 deletions test/robot/functional/stackql_mocked_from_cmd_line.robot
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,66 @@ Google AcceleratorTypes SQL verb pre changeover
... ${SELECT_ACCELERATOR_TYPES_DESC}
... ${SELECT_ACCELERATOR_TYPES_DESC_EXPECTED}

Window Function ROW_NUMBER Over AcceleratorTypes
Should StackQL Exec Inline Equal
... ${STACKQL_EXE}
... ${OKTA_SECRET_STR}
... ${GITHUB_SECRET_STR}
... ${K8S_SECRET_STR}
... ${REGISTRY_NO_VERIFY_CFG_STR}
... ${AUTH_CFG_STR}
... ${SQL_BACKEND_CFG_STR_CANONICAL}
... ${SELECT_ACCELERATOR_TYPES_ROW_NUMBER}
... ${SELECT_ACCELERATOR_TYPES_ROW_NUMBER_EXPECTED}

Window Function SUM OVER All AcceleratorTypes
Should StackQL Exec Inline Equal
... ${STACKQL_EXE}
... ${OKTA_SECRET_STR}
... ${GITHUB_SECRET_STR}
... ${K8S_SECRET_STR}
... ${REGISTRY_NO_VERIFY_CFG_STR}
... ${AUTH_CFG_STR}
... ${SQL_BACKEND_CFG_STR_CANONICAL}
... ${SELECT_ACCELERATOR_TYPES_SUM_OVER}
... ${SELECT_ACCELERATOR_TYPES_SUM_OVER_EXPECTED}

Window Function Running SUM Over AcceleratorTypes
Should StackQL Exec Inline Equal
... ${STACKQL_EXE}
... ${OKTA_SECRET_STR}
... ${GITHUB_SECRET_STR}
... ${K8S_SECRET_STR}
... ${REGISTRY_NO_VERIFY_CFG_STR}
... ${AUTH_CFG_STR}
... ${SQL_BACKEND_CFG_STR_CANONICAL}
... ${SELECT_ACCELERATOR_TYPES_RUNNING_SUM}
... ${SELECT_ACCELERATOR_TYPES_RUNNING_SUM_EXPECTED}

Window Function RANK Over AcceleratorTypes
Should StackQL Exec Inline Equal
... ${STACKQL_EXE}
... ${OKTA_SECRET_STR}
... ${GITHUB_SECRET_STR}
... ${K8S_SECRET_STR}
... ${REGISTRY_NO_VERIFY_CFG_STR}
... ${AUTH_CFG_STR}
... ${SQL_BACKEND_CFG_STR_CANONICAL}
... ${SELECT_ACCELERATOR_TYPES_RANK}
... ${SELECT_ACCELERATOR_TYPES_RANK_EXPECTED}

CTE Simple Select Over AcceleratorTypes
Should StackQL Exec Inline Equal
... ${STACKQL_EXE}
... ${OKTA_SECRET_STR}
... ${GITHUB_SECRET_STR}
... ${K8S_SECRET_STR}
... ${REGISTRY_NO_VERIFY_CFG_STR}
... ${AUTH_CFG_STR}
... ${SQL_BACKEND_CFG_STR_CANONICAL}
... ${SELECT_ACCELERATOR_TYPES_SIMPLE_CTE}
... ${SELECT_ACCELERATOR_TYPES_SIMPLE_CTE_EXPECTED}

Google Machine Types Select Paginated
Should Horrid Query StackQL Inline Equal
... ${STACKQL_EXE}
Expand Down
Loading