Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions sql/analyzer/costed_index_scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@
package analyzer

import (
"cmp"
"fmt"
"slices"
"sort"
"strings"
"time"

"github.com/dolthub/vitess/go/sqltypes"
"github.com/shopspring/decimal"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/expression/function/spatial"
Expand Down Expand Up @@ -835,6 +839,82 @@ type indexScanRangeBuilder struct {
leftover []sql.Expression
}

func inValsToMySQLRangeCollHelper[N cmp.Ordered](ctx *sql.Context, vals []any, typ sql.Type, precise bool) (sql.MySQLRangeCollection, bool) {
keys := make([]N, 0, len(vals))
for _, val := range vals {
switch v := val.(type) {
case int, int8, int16, int32, int64,
uint, uint8, uint16, uint32, uint64:
case float32:
if precise && float32(int(v)) != v {
continue
}
case float64:
if precise && float64(int(v)) != v {
continue
}
case decimal.Decimal:
if precise && !v.Equal(decimal.NewFromInt(v.IntPart())) {
continue
}
default:
return nil, false
}
key, inRange, err := typ.Convert(ctx, val)
if err != nil {
return nil, false
}
if inRange != sql.InRange {
continue
}
keys = append(keys, key.(N))
}

// TODO: for integers, if len(keys) - 1 == keys[len(keys)-1] - keys[0],
// then we can just have one continuous range. unsure if this is worth it
slices.Sort(keys)
keys = slices.Compact(keys)
res := make(sql.MySQLRangeCollection, len(keys))
for i, key := range keys {
res[i] = sql.MySQLRange{
sql.ClosedRangeColumnExpr(key, key, typ),
}
}

if len(res) == 0 {
return nil, true
}
return res, true
}

// inValsToMySQLRangeColl is a fast path for in filters over numeric columns.
func inValsToMySQLRangeColl(ctx *sql.Context, vals []any, typ sql.Type) (sql.MySQLRangeCollection, bool) {
switch typ.Type() {
case sqltypes.Int8:
return inValsToMySQLRangeCollHelper[int8](ctx, vals, typ, true)
case sqltypes.Int16:
return inValsToMySQLRangeCollHelper[int16](ctx, vals, typ, true)
case sqltypes.Int32:
return inValsToMySQLRangeCollHelper[int32](ctx, vals, typ, true)
case sqltypes.Int64:
return inValsToMySQLRangeCollHelper[int64](ctx, vals, typ, true)
case sqltypes.Uint8:
return inValsToMySQLRangeCollHelper[uint8](ctx, vals, typ, true)
case sqltypes.Uint16:
return inValsToMySQLRangeCollHelper[uint16](ctx, vals, typ, true)
case sqltypes.Uint32:
return inValsToMySQLRangeCollHelper[uint32](ctx, vals, typ, true)
case sqltypes.Uint64:
return inValsToMySQLRangeCollHelper[uint64](ctx, vals, typ, true)
case sqltypes.Float32:
return inValsToMySQLRangeCollHelper[float32](ctx, vals, typ, false)
case sqltypes.Float64:
return inValsToMySQLRangeCollHelper[float64](ctx, vals, typ, false)
default:
return nil, false
}
}

// buildRangeCollection converts our representation of the best index scan
// into the format that represents an index lookup, a list of sql.Range.
func (b *indexScanRangeBuilder) buildRangeCollection(f indexFilter) (sql.MySQLRangeCollection, error) {
Expand All @@ -848,6 +928,16 @@ func (b *indexScanRangeBuilder) buildRangeCollection(f indexFilter) (sql.MySQLRa
case *iScanOr:
ranges, err = b.rangeBuildOr(f, inScan)
case *iScanLeaf:
// When the filter is a simple IN, we can skip costly checks like building the RangeTree.
if f.Op() == sql.IndexScanOpInSet {
cets := b.idx.ColumnExpressionTypes()
if len(cets) == 1 {
var ok bool
if ranges, ok = inValsToMySQLRangeColl(b.ctx, f.setValues, cets[0].Type); ok {
return ranges, nil
}
}
}
ranges, err = b.rangeBuildLeaf(f, inScan)
default:
return nil, fmt.Errorf("unknown indexFilter type: %T", f)
Expand Down
27 changes: 16 additions & 11 deletions sql/index_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,17 @@ type MySQLIndexBuilder struct {
// NewMySQLIndexBuilder returns a new MySQLIndexBuilder. Used internally to construct a range that will later be passed to
// integrators through the Index function NewLookup.
func NewMySQLIndexBuilder(idx Index) *MySQLIndexBuilder {
colExprTypes := make(map[string]Type)
ranges := make(map[string][]MySQLRangeColumnExpr)
for _, cet := range idx.ColumnExpressionTypes() {
cets := idx.ColumnExpressionTypes()
colExprTypes := make(map[string]Type, len(cets))
ranges := make(map[string][]MySQLRangeColumnExpr, len(cets))
for _, cet := range cets {
typ := cet.Type
if _, ok := typ.(StringType); ok {
typ = typ.Promote()
}
colExprTypes[strings.ToLower(cet.Expression)] = typ
ranges[strings.ToLower(cet.Expression)] = []MySQLRangeColumnExpr{AllRangeColumnExpr(typ)}
expr := strings.ToLower(cet.Expression)
colExprTypes[expr] = typ
ranges[expr] = []MySQLRangeColumnExpr{AllRangeColumnExpr(typ)}
}
return &MySQLIndexBuilder{
idx: idx,
Expand Down Expand Up @@ -120,15 +122,19 @@ func (b *MySQLIndexBuilder) Equals(ctx *Context, colExpr string, keyType Type, k
for i, k := range keys {
// if converting from float to int results in rounding, then it's empty range
if t, ok := colTyp.(NumberType); ok && t.IsNumericType() && !t.IsFloat() {
f, c := floor(k), ceil(k)
switch k.(type) {
case float32, float64:
if f != c {
switch k := k.(type) {
case float32:
if float32(int64(k)) != k {
potentialRanges[i] = EmptyRangeColumnExpr(colTyp)
continue
}
case float64:
if float64(int64(k)) != k {
potentialRanges[i] = EmptyRangeColumnExpr(colTyp)
continue
}
case decimal.Decimal:
if !f.(decimal.Decimal).Equals(c.(decimal.Decimal)) {
if !k.Equal(decimal.NewFromInt(k.IntPart())) {
potentialRanges[i] = EmptyRangeColumnExpr(colTyp)
continue
}
Expand Down Expand Up @@ -600,7 +606,6 @@ func (b *MySQLIndexBuilder) updateCol(ctx *Context, colExpr string, potentialRan
var newRanges []MySQLRangeColumnExpr
for _, currentRange := range currentRanges {
for _, potentialRange := range potentialRanges {

newRange, ok, err := currentRange.TryIntersect(potentialRange)
if err != nil {
b.isInvalid = true
Expand Down