diff --git a/sql/analyzer/costed_index_scan.go b/sql/analyzer/costed_index_scan.go
index 3bc2ec2d23..f27e6002a7 100644
--- a/sql/analyzer/costed_index_scan.go
+++ b/sql/analyzer/costed_index_scan.go
@@ -15,12 +15,16 @@
 package analyzer
 
 import (
+	"cmp"
 	"fmt"
 	"slices"
 	"sort"
 	"strings"
 	"time"
 
+	"github.com/dolthub/vitess/go/sqltypes"
+	"github.com/shopspring/decimal"
+
 	"github.com/dolthub/go-mysql-server/sql"
 	"github.com/dolthub/go-mysql-server/sql/expression"
 	"github.com/dolthub/go-mysql-server/sql/expression/function/spatial"
@@ -835,6 +839,95 @@ type indexScanRangeBuilder struct {
 	leftover []sql.Expression
 }
 
+// inValsToMySQLRangeCollHelper turns the values of an IN filter into one closed
+// point range per distinct value, with N as the Go type of the converted keys.
+// When precise is true, float and decimal values that are not whole numbers are
+// dropped; values that typ.Convert reports as not in range are always dropped.
+// The second result is false when a value has an unsupported type or fails to
+// convert, in which case the caller should use the general range builder.
+func inValsToMySQLRangeCollHelper[N cmp.Ordered](ctx *sql.Context, vals []any, typ sql.Type, precise bool) (sql.MySQLRangeCollection, bool) {
+	keys := make([]N, 0, len(vals))
+	for _, val := range vals {
+		switch v := val.(type) {
+		case int, int8, int16, int32, int64,
+			uint, uint8, uint16, uint32, uint64:
+		case float32:
+			// int64 rather than int: int is only 32 bits wide on some platforms.
+			if precise && float32(int64(v)) != v {
+				continue
+			}
+		case float64:
+			if precise && float64(int64(v)) != v {
+				continue
+			}
+		case decimal.Decimal:
+			if precise && !v.Equal(decimal.NewFromInt(v.IntPart())) {
+				continue
+			}
+		default:
+			return nil, false
+		}
+		converted, inRange, err := typ.Convert(ctx, val)
+		if err != nil {
+			return nil, false
+		}
+		if inRange != sql.InRange {
+			continue
+		}
+		key, ok := converted.(N)
+		if !ok {
+			// Unexpected Go type from Convert: use the slow path instead of panicking.
+			return nil, false
+		}
+		keys = append(keys, key)
+	}
+	if len(keys) == 0 {
+		return nil, true
+	}
+
+	// TODO: for integers, if len(keys) - 1 == keys[len(keys)-1] - keys[0],
+	// then we can just have one continuous range. unsure if this is worth it
+	slices.Sort(keys)
+	keys = slices.Compact(keys)
+	res := make(sql.MySQLRangeCollection, len(keys))
+	for i, key := range keys {
+		res[i] = sql.MySQLRange{
+			sql.ClosedRangeColumnExpr(key, key, typ),
+		}
+	}
+	return res, true
+}
+
+// inValsToMySQLRangeColl is a fast path for IN filters over numeric columns.
+// Whole-number checks are only requested for integer column types. The second
+// result is false when typ is not a supported numeric type or a value is rejected.
+func inValsToMySQLRangeColl(ctx *sql.Context, vals []any, typ sql.Type) (sql.MySQLRangeCollection, bool) {
+	switch typ.Type() {
+	case sqltypes.Int8:
+		return inValsToMySQLRangeCollHelper[int8](ctx, vals, typ, true)
+	case sqltypes.Int16:
+		return inValsToMySQLRangeCollHelper[int16](ctx, vals, typ, true)
+	case sqltypes.Int32:
+		return inValsToMySQLRangeCollHelper[int32](ctx, vals, typ, true)
+	case sqltypes.Int64:
+		return inValsToMySQLRangeCollHelper[int64](ctx, vals, typ, true)
+	case sqltypes.Uint8:
+		return inValsToMySQLRangeCollHelper[uint8](ctx, vals, typ, true)
+	case sqltypes.Uint16:
+		return inValsToMySQLRangeCollHelper[uint16](ctx, vals, typ, true)
+	case sqltypes.Uint32:
+		return inValsToMySQLRangeCollHelper[uint32](ctx, vals, typ, true)
+	case sqltypes.Uint64:
+		return inValsToMySQLRangeCollHelper[uint64](ctx, vals, typ, true)
+	case sqltypes.Float32:
+		return inValsToMySQLRangeCollHelper[float32](ctx, vals, typ, false)
+	case sqltypes.Float64:
+		return inValsToMySQLRangeCollHelper[float64](ctx, vals, typ, false)
+	default:
+		return nil, false
+	}
+}
+
 // buildRangeCollection converts our representation of the best index scan
 // into the format that represents an index lookup, a list of sql.Range.
 func (b *indexScanRangeBuilder) buildRangeCollection(f indexFilter) (sql.MySQLRangeCollection, error) {
@@ -848,6 +941,16 @@ func (b *indexScanRangeBuilder) buildRangeCollection(f indexFilter) (sql.MySQLRa
 	case *iScanOr:
 		ranges, err = b.rangeBuildOr(f, inScan)
 	case *iScanLeaf:
+		// When the filter is a simple IN, we can skip costly checks like building the RangeTree.
+		if f.Op() == sql.IndexScanOpInSet {
+			cets := b.idx.ColumnExpressionTypes()
+			if len(cets) == 1 {
+				var ok bool
+				if ranges, ok = inValsToMySQLRangeColl(b.ctx, f.setValues, cets[0].Type); ok {
+					return ranges, nil
+				}
+			}
+		}
 		ranges, err = b.rangeBuildLeaf(f, inScan)
 	default:
 		return nil, fmt.Errorf("unknown indexFilter type: %T", f)
diff --git a/sql/index_builder.go b/sql/index_builder.go
index 40ee878a7e..91b09b03fd 100644
--- a/sql/index_builder.go
+++ b/sql/index_builder.go
@@ -42,15 +42,17 @@ type MySQLIndexBuilder struct {
 // NewMySQLIndexBuilder returns a new MySQLIndexBuilder. Used internally to construct a range that will later be passed to
 // integrators through the Index function NewLookup.
 func NewMySQLIndexBuilder(idx Index) *MySQLIndexBuilder {
-	colExprTypes := make(map[string]Type)
-	ranges := make(map[string][]MySQLRangeColumnExpr)
-	for _, cet := range idx.ColumnExpressionTypes() {
+	cets := idx.ColumnExpressionTypes()
+	colExprTypes := make(map[string]Type, len(cets))
+	ranges := make(map[string][]MySQLRangeColumnExpr, len(cets))
+	for _, cet := range cets {
 		typ := cet.Type
 		if _, ok := typ.(StringType); ok {
 			typ = typ.Promote()
 		}
-		colExprTypes[strings.ToLower(cet.Expression)] = typ
-		ranges[strings.ToLower(cet.Expression)] = []MySQLRangeColumnExpr{AllRangeColumnExpr(typ)}
+		expr := strings.ToLower(cet.Expression)
+		colExprTypes[expr] = typ
+		ranges[expr] = []MySQLRangeColumnExpr{AllRangeColumnExpr(typ)}
 	}
 	return &MySQLIndexBuilder{
 		idx:          idx,
@@ -120,15 +122,19 @@ func (b *MySQLIndexBuilder) Equals(ctx *Context, colExpr string, keyType Type, k
 	for i, k := range keys {
 		// if converting from float to int results in rounding, then it's empty range
 		if t, ok := colTyp.(NumberType); ok && t.IsNumericType() && !t.IsFloat() {
-			f, c := floor(k), ceil(k)
-			switch k.(type) {
-			case float32, float64:
-				if f != c {
+			switch k := k.(type) {
+			case float32:
+				if float32(int64(k)) != k {
+					potentialRanges[i] = EmptyRangeColumnExpr(colTyp)
+					continue
+				}
+			case float64:
+				if float64(int64(k)) != k {
 					potentialRanges[i] = EmptyRangeColumnExpr(colTyp)
 					continue
 				}
 			case decimal.Decimal:
-				if !f.(decimal.Decimal).Equals(c.(decimal.Decimal)) {
+				if !k.Equal(decimal.NewFromInt(k.IntPart())) {
 					potentialRanges[i] = EmptyRangeColumnExpr(colTyp)
 					continue
 				}
@@ -600,7 +606,6 @@ func (b *MySQLIndexBuilder) updateCol(ctx *Context, colExpr string, potentialRan
 	var newRanges []MySQLRangeColumnExpr
 	for _, currentRange := range currentRanges {
 		for _, potentialRange := range potentialRanges {
-
 			newRange, ok, err := currentRange.TryIntersect(potentialRange)
 			if err != nil {
 				b.isInvalid = true