From 4adb4cb159fc36da15669f872d7afb972024830b Mon Sep 17 00:00:00 2001 From: gengliqi Date: Thu, 6 Mar 2025 21:44:31 +0800 Subject: [PATCH 01/84] support late materialization for other condition Signed-off-by: gengliqi --- dbms/src/Columns/ColumnAggregateFunction.h | 8 +- dbms/src/Columns/ColumnArray.h | 8 +- dbms/src/Columns/ColumnConst.h | 5 +- dbms/src/Columns/ColumnDecimal.cpp | 14 +- dbms/src/Columns/ColumnDecimal.h | 6 +- dbms/src/Columns/ColumnFixedString.cpp | 10 +- dbms/src/Columns/ColumnFixedString.h | 3 +- dbms/src/Columns/ColumnFunction.h | 2 +- dbms/src/Columns/ColumnNullable.cpp | 16 +- dbms/src/Columns/ColumnNullable.h | 3 +- dbms/src/Columns/ColumnString.h | 10 +- dbms/src/Columns/ColumnTuple.h | 8 +- dbms/src/Columns/ColumnVector.h | 14 +- dbms/src/Columns/IColumn.h | 19 +- dbms/src/Columns/IColumnDummy.h | 5 +- dbms/src/Columns/filterColumn.cpp | 6 + .../Columns/tests/gtest_column_insertFrom.cpp | 8 + dbms/src/Interpreters/JoinUtils.cpp | 2 + dbms/src/Interpreters/JoinV2/HashJoin.cpp | 371 +++++++++++++++--- dbms/src/Interpreters/JoinV2/HashJoin.h | 4 +- .../src/Interpreters/JoinV2/HashJoinProbe.cpp | 322 +++++++++------ dbms/src/Interpreters/JoinV2/HashJoinProbe.h | 20 +- .../Interpreters/JoinV2/HashJoinRowLayout.h | 4 + .../Interpreters/NullAwareSemiJoinHelper.cpp | 36 +- .../Interpreters/NullAwareSemiJoinHelper.h | 2 +- 25 files changed, 637 insertions(+), 269 deletions(-) diff --git a/dbms/src/Columns/ColumnAggregateFunction.h b/dbms/src/Columns/ColumnAggregateFunction.h index eb86da9d3f2..186919fc570 100644 --- a/dbms/src/Columns/ColumnAggregateFunction.h +++ b/dbms/src/Columns/ColumnAggregateFunction.h @@ -131,10 +131,12 @@ class ColumnAggregateFunction final : public COWPtrHelper= start + length); + for (size_t i = start; i < start + length; ++i) + insertFrom(src_, selective_offsets[i]); } void insertFrom(ConstAggregateDataPtr __restrict place); diff --git a/dbms/src/Columns/ColumnArray.h b/dbms/src/Columns/ColumnArray.h index 
e719f5a25d7..6f456fec73c 100644 --- a/dbms/src/Columns/ColumnArray.h +++ b/dbms/src/Columns/ColumnArray.h @@ -200,10 +200,12 @@ class ColumnArray final : public COWPtrHelper insertFrom(src_, n); } - void insertSelectiveFrom(const IColumn & src_, const Offsets & selective_offsets) override + void insertSelectiveRangeFrom(const IColumn & src_, const Offsets & selective_offsets, size_t start, size_t length) + override { - for (auto position : selective_offsets) - insertFrom(src_, position); + RUNTIME_CHECK(selective_offsets.size() >= start + length); + for (size_t i = start; i < start + length; ++i) + insertFrom(src_, selective_offsets[i]); } void insertDefault() override; diff --git a/dbms/src/Columns/ColumnConst.h b/dbms/src/Columns/ColumnConst.h index 6f4ef69dc54..8bb2976d376 100644 --- a/dbms/src/Columns/ColumnConst.h +++ b/dbms/src/Columns/ColumnConst.h @@ -80,10 +80,7 @@ class ColumnConst final : public COWPtrHelper void insertManyFrom(const IColumn &, size_t, size_t length) override { s += length; } - void insertSelectiveFrom(const IColumn &, const Offsets & selective_offsets) override - { - s += selective_offsets.size(); - } + void insertSelectiveRangeFrom(const IColumn &, const Offsets &, size_t, size_t length) override { s += length; } void insertMany(const Field &, size_t length) override { s += length; } diff --git a/dbms/src/Columns/ColumnDecimal.cpp b/dbms/src/Columns/ColumnDecimal.cpp index b9c031d6524..45df9eb6816 100644 --- a/dbms/src/Columns/ColumnDecimal.cpp +++ b/dbms/src/Columns/ColumnDecimal.cpp @@ -920,14 +920,18 @@ void ColumnDecimal::insertManyFrom(const IColumn & src, size_t position, size } template -void ColumnDecimal::insertSelectiveFrom(const IColumn & src, const IColumn::Offsets & selective_offsets) +void ColumnDecimal::insertSelectiveRangeFrom( + const IColumn & src, + const IColumn::Offsets & selective_offsets, + size_t start, + size_t length) { + RUNTIME_CHECK(selective_offsets.size() >= start + length); const auto & src_data = 
static_cast(src).data; size_t old_size = data.size(); - size_t to_add_size = selective_offsets.size(); - data.resize(old_size + to_add_size); - for (size_t i = 0; i < to_add_size; ++i) - data[i + old_size] = src_data[selective_offsets[i]]; + data.resize(old_size + length); + for (size_t i = 0; i < length; ++i) + data[i + old_size] = src_data[selective_offsets[i + start]]; } #pragma GCC diagnostic pop diff --git a/dbms/src/Columns/ColumnDecimal.h b/dbms/src/Columns/ColumnDecimal.h index 354519fa379..013719c2172 100644 --- a/dbms/src/Columns/ColumnDecimal.h +++ b/dbms/src/Columns/ColumnDecimal.h @@ -155,7 +155,11 @@ class ColumnDecimal final : public COWPtrHelper::Type>(x)); } void insertRangeFrom(const IColumn & src, size_t start, size_t length) override; void insertManyFrom(const IColumn & src_, size_t position, size_t length) override; - void insertSelectiveFrom(const IColumn & src_, const IColumn::Offsets & selective_offsets) override; + void insertSelectiveRangeFrom( + const IColumn & src_, + const IColumn::Offsets & selective_offsets, + size_t start, + size_t length) override; void popBack(size_t n) override { data.resize_assume_reserved(data.size() - n); } StringRef getRawData() const override diff --git a/dbms/src/Columns/ColumnFixedString.cpp b/dbms/src/Columns/ColumnFixedString.cpp index d02de8dfd0c..760e5aa6d26 100644 --- a/dbms/src/Columns/ColumnFixedString.cpp +++ b/dbms/src/Columns/ColumnFixedString.cpp @@ -89,16 +89,20 @@ void ColumnFixedString::insertManyFrom(const IColumn & src_, size_t position, si memcpySmallAllowReadWriteOverflow15(&chars[i], src_char_ptr, n); } -void ColumnFixedString::insertSelectiveFrom(const IColumn & src_, const Offsets & selective_offsets) +void ColumnFixedString::insertSelectiveRangeFrom( + const IColumn & src_, + const Offsets & selective_offsets, + size_t start, + size_t length) { const auto & src = static_cast(src_); if (n != src.getN()) throw Exception("Size of FixedString doesn't match", 
ErrorCodes::SIZE_OF_FIXED_STRING_DOESNT_MATCH); size_t old_size = chars.size(); - size_t new_size = old_size + selective_offsets.size() * n; + size_t new_size = old_size + length * n; chars.resize(new_size); const auto & src_chars = src.chars; - for (size_t i = old_size, j = 0; i < new_size; i += n, ++j) + for (size_t i = old_size, j = start; i < new_size; i += n, ++j) memcpySmallAllowReadWriteOverflow15(&chars[i], &src_chars[selective_offsets[j] * n], n); } diff --git a/dbms/src/Columns/ColumnFixedString.h b/dbms/src/Columns/ColumnFixedString.h index 7b6dd4d42b6..22aca6d551e 100644 --- a/dbms/src/Columns/ColumnFixedString.h +++ b/dbms/src/Columns/ColumnFixedString.h @@ -105,7 +105,8 @@ class ColumnFixedString final : public COWPtrHelper void insertManyFrom(const IColumn & src_, size_t position, size_t length) override; - void insertSelectiveFrom(const IColumn & src_, const Offsets & selective_offsets) override; + void insertSelectiveRangeFrom(const IColumn & src_, const Offsets & selective_offsets, size_t start, size_t length) + override; void insertData(const char * pos, size_t length) override; diff --git a/dbms/src/Columns/ColumnFunction.h b/dbms/src/Columns/ColumnFunction.h index 7813d2a0c13..d8a8a699582 100644 --- a/dbms/src/Columns/ColumnFunction.h +++ b/dbms/src/Columns/ColumnFunction.h @@ -99,7 +99,7 @@ class ColumnFunction final : public COWPtrHelper throw Exception("Cannot insert into " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void insertSelectiveFrom(const IColumn &, const Offsets &) override + void insertSelectiveRangeFrom(const IColumn &, const Offsets &, size_t, size_t) override { throw Exception("Cannot insert into " + getName(), ErrorCodes::NOT_IMPLEMENTED); } diff --git a/dbms/src/Columns/ColumnNullable.cpp b/dbms/src/Columns/ColumnNullable.cpp index c407877f465..38a23280f13 100644 --- a/dbms/src/Columns/ColumnNullable.cpp +++ b/dbms/src/Columns/ColumnNullable.cpp @@ -437,17 +437,15 @@ void ColumnNullable::insertManyFrom(const IColumn & src, 
size_t n, size_t length map.resize_fill(map.size() + length, src_concrete.getNullMapData()[n]); } -void ColumnNullable::insertSelectiveFrom(const IColumn & src, const Offsets & selective_offsets) +void ColumnNullable::insertSelectiveRangeFrom( + const IColumn & src, + const Offsets & selective_offsets, + size_t start, + size_t length) { const auto & src_concrete = static_cast(src); - getNestedColumn().insertSelectiveFrom(src_concrete.getNestedColumn(), selective_offsets); - auto & map = getNullMapData(); - const auto & src_map = src_concrete.getNullMapData(); - size_t old_size = map.size(); - size_t to_add_size = selective_offsets.size(); - map.resize(old_size + to_add_size); - for (size_t i = 0; i < to_add_size; ++i) - map[i + old_size] = src_map[selective_offsets[i]]; + getNestedColumn().insertSelectiveRangeFrom(src_concrete.getNestedColumn(), selective_offsets, start, length); + getNullMapColumn().insertSelectiveRangeFrom(src_concrete.getNullMapColumn(), selective_offsets, start, length); } void ColumnNullable::popBack(size_t n) diff --git a/dbms/src/Columns/ColumnNullable.h b/dbms/src/Columns/ColumnNullable.h index dd074ac8cb1..f73c488baa6 100644 --- a/dbms/src/Columns/ColumnNullable.h +++ b/dbms/src/Columns/ColumnNullable.h @@ -140,7 +140,8 @@ class ColumnNullable final : public COWPtrHelper void insert(const Field & x) override; void insertFrom(const IColumn & src, size_t n) override; void insertManyFrom(const IColumn & src, size_t n, size_t length) override; - void insertSelectiveFrom(const IColumn & src, const Offsets & selective_offsets) override; + void insertSelectiveRangeFrom(const IColumn & src, const Offsets & selective_offsets, size_t start, size_t length) + override; void insertDefault() override { diff --git a/dbms/src/Columns/ColumnString.h b/dbms/src/Columns/ColumnString.h index b7a960acb5c..494b3b0894c 100644 --- a/dbms/src/Columns/ColumnString.h +++ b/dbms/src/Columns/ColumnString.h @@ -223,12 +223,14 @@ class ColumnString final : public 
COWPtrHelper insertFromImpl(src, position); } - void insertSelectiveFrom(const IColumn & src_, const Offsets & selective_offsets) override + void insertSelectiveRangeFrom(const IColumn & src_, const Offsets & selective_offsets, size_t start, size_t length) + override { + RUNTIME_CHECK(selective_offsets.size() >= start + length); const auto & src = static_cast(src_); - offsets.reserve(offsets.size() + selective_offsets.size()); - for (auto position : selective_offsets) - insertFromImpl(src, position); + offsets.reserve(offsets.size() + length); + for (size_t i = start; i < start + length; ++i) + insertFromImpl(src, selective_offsets[i]); } template diff --git a/dbms/src/Columns/ColumnTuple.h b/dbms/src/Columns/ColumnTuple.h index a1a37c6a1cf..8c26a28a7cd 100644 --- a/dbms/src/Columns/ColumnTuple.h +++ b/dbms/src/Columns/ColumnTuple.h @@ -72,10 +72,12 @@ class ColumnTuple final : public COWPtrHelper insertFrom(src_, n); } - void insertSelectiveFrom(const IColumn & src_, const Offsets & selective_offsets) override + void insertSelectiveRangeFrom(const IColumn & src_, const Offsets & selective_offsets, size_t start, size_t length) + override { - for (auto position : selective_offsets) - insertFrom(src_, position); + RUNTIME_CHECK(selective_offsets.size() >= start + length); + for (size_t i = start; i < start + length; ++i) + insertFrom(src_, selective_offsets[i]); } void insertDefault() override; diff --git a/dbms/src/Columns/ColumnVector.h b/dbms/src/Columns/ColumnVector.h index a78f0539517..0e0d6584dd8 100644 --- a/dbms/src/Columns/ColumnVector.h +++ b/dbms/src/Columns/ColumnVector.h @@ -236,14 +236,18 @@ class ColumnVector final : public COWPtrHelper= start + length); const auto & src_container = static_cast(src).getData(); size_t old_size = data.size(); - size_t to_add_size = selective_offsets.size(); - data.resize(old_size + to_add_size); - for (size_t i = 0; i < to_add_size; ++i) - data[i + old_size] = src_container[selective_offsets[i]]; + data.resize(old_size + 
length); + for (size_t i = 0; i < length; ++i) + data[i + old_size] = src_container[selective_offsets[start + i]]; } void insertMany(const Field & field, size_t length) override diff --git a/dbms/src/Columns/IColumn.h b/dbms/src/Columns/IColumn.h index 3ab2b0d9c8b..19bd1d2ea10 100644 --- a/dbms/src/Columns/IColumn.h +++ b/dbms/src/Columns/IColumn.h @@ -156,7 +156,16 @@ class IColumn : public COWPtr /// Note: the source column and the destination column must be of the same type, can not ColumnXXX->insertSelectiveFrom(ConstColumnXXX, ...) using Offset = UInt64; using Offsets = PaddedPODArray; - virtual void insertSelectiveFrom(const IColumn & src, const Offsets & selective_offsets) = 0; + void insertSelectiveFrom(const IColumn & src, const Offsets & selective_offsets) + { + insertSelectiveRangeFrom(src, selective_offsets, 0, selective_offsets.size()); + } + virtual void insertSelectiveRangeFrom( + const IColumn & src, + const Offsets & selective_offsets, + size_t start, + size_t length) + = 0; /// Appends one field multiple times. Can be optimized in inherited classes. virtual void insertMany(const Field & field, size_t length) @@ -246,12 +255,12 @@ class IColumn : public COWPtr /// Count the serialize byte size and added to the byte_size. /// The byte_size.size() must be equal to the column size. + virtual void countSerializeByteSize(PaddedPODArray & /* byte_size */) const = 0; virtual void countSerializeByteSizeForCmp( PaddedPODArray & /* byte_size */, const NullMap * /*nullmap*/, const TiDB::TiDBCollatorPtr & /* collator */) const = 0; - virtual void countSerializeByteSize(PaddedPODArray & /* byte_size */) const = 0; /// Count the serialize byte size and added to the byte_size called by ColumnArray. /// array_offsets is the offsets of ColumnArray. 
@@ -331,20 +340,20 @@ class IColumn : public COWPtr /// } /// for (auto & column_ptr : mutable_columns) /// column_ptr->flushNTAlignBuffer(); + virtual void deserializeAndInsertFromPos(PaddedPODArray & /* pos */, bool /* use_nt_align_buffer */) = 0; virtual void deserializeForCmpAndInsertFromPos(PaddedPODArray & /* pos */, bool /* use_nt_align_buffer */) = 0; - virtual void deserializeAndInsertFromPos(PaddedPODArray & /* pos */, bool /* use_nt_align_buffer */) = 0; /// Deserialize and insert data from pos and forward each pos[i] to the end of serialized data. /// Only called by ColumnArray. /// array_offsets is the offsets of ColumnArray. /// The last pos.size() elements of array_offsets can be used to get the length of elements from each pos. - virtual void deserializeForCmpAndInsertFromPosColumnArray( + virtual void deserializeAndInsertFromPosForColumnArray( PaddedPODArray & /* pos */, const Offsets & /* array_offsets */, bool /* use_nt_align_buffer */) = 0; - virtual void deserializeAndInsertFromPosForColumnArray( + virtual void deserializeForCmpAndInsertFromPosColumnArray( PaddedPODArray & /* pos */, const Offsets & /* array_offsets */, bool /* use_nt_align_buffer */) diff --git a/dbms/src/Columns/IColumnDummy.h b/dbms/src/Columns/IColumnDummy.h index 9d64fb6d821..8ebb9e384e4 100644 --- a/dbms/src/Columns/IColumnDummy.h +++ b/dbms/src/Columns/IColumnDummy.h @@ -217,10 +217,7 @@ class IColumnDummy : public IColumn void insertManyFrom(const IColumn &, size_t, size_t length) override { s += length; } - void insertSelectiveFrom(const IColumn &, const Offsets & selective_offsets) override - { - s += selective_offsets.size(); - } + void insertSelectiveRangeFrom(const IColumn &, const Offsets &, size_t, size_t length) override { s += length; } void insertRangeFrom(const IColumn & /*src*/, size_t /*start*/, size_t length) override { s += length; } diff --git a/dbms/src/Columns/filterColumn.cpp b/dbms/src/Columns/filterColumn.cpp index c1a0519200b..33a95159b88 100644 
--- a/dbms/src/Columns/filterColumn.cpp +++ b/dbms/src/Columns/filterColumn.cpp @@ -322,6 +322,12 @@ INSTANTIATE(Decimal32, DecimalPaddedPODArray) INSTANTIATE(Decimal64, DecimalPaddedPODArray) INSTANTIATE(Decimal128, DecimalPaddedPODArray) INSTANTIATE(Decimal256, DecimalPaddedPODArray) +// Cannot use INSTANTIATE macro because `const T * data_pos` + `T: char *` will be interpreted as `const char **` +template void filterImpl>( + const UInt8 * filt_pos, + const UInt8 * filt_end, + char * const * data_pos, + PaddedPODArray & res_data); #undef INSTANTIATE diff --git a/dbms/src/Columns/tests/gtest_column_insertFrom.cpp b/dbms/src/Columns/tests/gtest_column_insertFrom.cpp index 235acf51b71..b2e095946a1 100644 --- a/dbms/src/Columns/tests/gtest_column_insertFrom.cpp +++ b/dbms/src/Columns/tests/gtest_column_insertFrom.cpp @@ -83,9 +83,17 @@ class TestColumnInsertFrom : public ::testing::Test selective_offsets.push_back(4); for (size_t position : selective_offsets) cols[0]->insertFrom(*column_ptr, position); + std::vector> range_test = {{0, 1}, {1, 2}, {0, 3}, {2, 1}, {1, 1}}; + for (auto [start, length] : range_test) + { + for (size_t i = start; i < start + length; ++i) + cols[0]->insertFrom(*column_ptr, selective_offsets[i]); + } for (size_t position : selective_offsets) cols[0]->insertFrom(*column_ptr, position); cols[1]->insertSelectiveFrom(*column_ptr, selective_offsets); + for (auto [start, length] : range_test) + cols[1]->insertSelectiveRangeFrom(*column_ptr, selective_offsets, start, length); cols[1]->insertSelectiveFrom(*column_ptr, selective_offsets); { ColumnWithTypeAndName ref(std::move(cols[0]), col_with_type_and_name.type, ""); diff --git a/dbms/src/Interpreters/JoinUtils.cpp b/dbms/src/Interpreters/JoinUtils.cpp index dfd772e5e2c..c3b02cf7bba 100644 --- a/dbms/src/Interpreters/JoinUtils.cpp +++ b/dbms/src/Interpreters/JoinUtils.cpp @@ -204,6 +204,7 @@ void mergeNullAndFilterResult( } else { + RUNTIME_CHECK(filter_column.size() == nullmap_vec->size()); if 
(null_as_true) { for (size_t i = 0; i < nullmap_vec->size(); ++i) @@ -224,6 +225,7 @@ void mergeNullAndFilterResult( } else { + RUNTIME_CHECK(filter_column.size() == filter_vec->size()); for (size_t i = 0; i < filter_vec->size(); ++i) filter_column[i] = filter_column[i] && (*filter_vec)[i]; } diff --git a/dbms/src/Interpreters/JoinV2/HashJoin.cpp b/dbms/src/Interpreters/JoinV2/HashJoin.cpp index c71c469a838..e768e4edf7a 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoin.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoin.cpp @@ -193,7 +193,7 @@ void HashJoin::initRowLayoutAndHashJoinMethod() std::unordered_set raw_required_key_index_set; if (method != HashJoinKeyMethod::KeySerialized) { - /// Move all raw join key column to the end of the join key. + /// Move all raw required join key column to the end of the join key. Names new_key_names_left, new_key_names_right; BoolVec raw_required_key_flag(keys_size); for (size_t i = 0; i < keys_size; ++i) @@ -241,12 +241,37 @@ void HashJoin::initRowLayoutAndHashJoinMethod() key_names_right.swap(new_key_names_right); } + row_layout.other_required_count_for_other_condition = 0; size_t columns = right_sample_block_pruned.columns(); + BoolVec required_columns_flag(columns); for (size_t i = 0; i < columns; ++i) { if (raw_required_key_index_set.contains(i)) + { + required_columns_flag[i] = true; continue; - auto & c = right_sample_block_pruned.getByPosition(i); + } + auto & c = right_sample_block_pruned.safeGetByPosition(i); + if (required_columns_names_set_for_other_condition.contains(c.name)) + { + ++row_layout.other_required_count_for_other_condition; + required_columns_flag[i] = true; + if (c.column->valuesHaveFixedSize()) + { + row_layout.other_column_fixed_size += c.column->sizeOfValueIfFixed(); + row_layout.other_required_column_indexes.push_back({i, true}); + } + else + { + row_layout.other_required_column_indexes.push_back({i, false}); + } + } + } + for (size_t i = 0; i < columns; ++i) + { + if (required_columns_flag[i]) + 
continue; + auto & c = right_sample_block_pruned.safeGetByPosition(i); if (c.column->valuesHaveFixedSize()) { row_layout.other_column_fixed_size += c.column->sizeOfValueIfFixed(); @@ -257,6 +282,8 @@ void HashJoin::initRowLayoutAndHashJoinMethod() row_layout.other_required_column_indexes.push_back({i, false}); } } + RUNTIME_CHECK( + row_layout.raw_required_key_column_indexes.size() + row_layout.other_required_column_indexes.size() == columns); } void HashJoin::initBuild(const Block & sample_block, size_t build_concurrency_) @@ -321,6 +348,21 @@ void HashJoin::initProbe(const Block & sample_block, size_t probe_concurrency_) all_sample_block_pruned.insert(std::move(new_column)); } + if (has_other_condition) + { + left_required_flag_for_other_condition.resize(left_sample_block_pruned.columns()); + for (const auto & name : required_columns_names_set_for_other_condition) + { + RUNTIME_CHECK_MSG( + all_sample_block_pruned.has(name), + "all_sample_block_pruned should have {} in required_columns_names_set_for_other_condition", + name); + if (!left_sample_block_pruned.has(name)) + continue; + + left_required_flag_for_other_condition[left_sample_block_pruned.getPositionByName(name)] = true; + } + } probe_concurrency = probe_concurrency_; active_probe_worker = probe_concurrency; @@ -503,6 +545,7 @@ Block HashJoin::probeBlock(JoinProbeContext & context, size_t stream_index) context.prepareForHashProbe( method, kind, + non_equal_conditions.other_cond_expr != nullptr, key_names_left, non_equal_conditions.left_filter_column, probe_output_name_set, @@ -513,11 +556,6 @@ Block HashJoin::probeBlock(JoinProbeContext & context, size_t stream_index) auto & wd = probe_workers_data[stream_index]; size_t left_columns = left_sample_block_pruned.columns(); size_t right_columns = right_sample_block_pruned.columns(); - RUNTIME_CHECK_MSG( - context.block.columns() == left_columns, - "columns of block from probe side {} != columns of left_sample_block_pruned {}", - context.block.columns(), - 
left_columns); if (!wd.result_block) { for (size_t i = 0; i < left_columns + right_columns; ++i) @@ -528,9 +566,32 @@ Block HashJoin::probeBlock(JoinProbeContext & context, size_t stream_index) } } - MutableColumns added_columns(right_columns); - for (size_t i = 0; i < right_columns; ++i) - added_columns[i] = wd.result_block.safeGetByPosition(left_columns + i).column->assumeMutable(); + bool late_materialization = false; + if (has_other_condition) + { + late_materialization + = row_layout.other_required_count_for_other_condition < row_layout.other_required_column_indexes.size(); + } + + MutableColumns added_columns; + if (late_materialization) + { + for (auto [column_index, _] : row_layout.raw_required_key_column_indexes) + added_columns.emplace_back( + wd.result_block.safeGetByPosition(left_columns + column_index).column->assumeMutable()); + for (size_t i = 0; i < row_layout.other_required_count_for_other_condition; ++i) + { + size_t column_index = row_layout.other_required_column_indexes[i].first; + added_columns.emplace_back( + wd.result_block.safeGetByPosition(left_columns + column_index).column->assumeMutable()); + } + } + else + { + added_columns.resize(right_columns); + for (size_t i = 0; i < right_columns; ++i) + added_columns[i] = wd.result_block.safeGetByPosition(left_columns + i).column->assumeMutable(); + } Stopwatch watch; joinProbeBlock( @@ -538,6 +599,7 @@ Block HashJoin::probeBlock(JoinProbeContext & context, size_t stream_index) wd, method, kind, + late_materialization, non_equal_conditions, settings, pointer_table, @@ -548,23 +610,53 @@ Block HashJoin::probeBlock(JoinProbeContext & context, size_t stream_index) if (context.isCurrentProbeFinished()) wd.probe_handle_rows += context.rows; - for (size_t i = 0; i < right_columns; ++i) - wd.result_block.safeGetByPosition(left_columns + i).column = std::move(added_columns[i]); + if (late_materialization) + { + size_t idx = 0; + for (auto [column_index, _] : row_layout.raw_required_key_column_indexes) + 
wd.result_block.safeGetByPosition(left_columns + column_index).column = std::move(added_columns[idx++]); + for (size_t i = 0; i < row_layout.other_required_count_for_other_condition; ++i) + { + size_t column_index = row_layout.other_required_column_indexes[i].first; + wd.result_block.safeGetByPosition(left_columns + column_index).column = std::move(added_columns[idx++]); + } + } + else + { + for (size_t i = 0; i < right_columns; ++i) + wd.result_block.safeGetByPosition(left_columns + i).column = std::move(added_columns[i]); + } - if (!wd.selective_offsets.empty()) + if (wd.selective_offsets.empty()) + return output_block_after_finalize; + + if (has_other_condition) { + // Always using late materialization for left side for (size_t i = 0; i < left_columns; ++i) { + if (!left_required_flag_for_other_condition[i]) + continue; wd.result_block.safeGetByPosition(i).column->assumeMutable()->insertSelectiveFrom( *context.block.safeGetByPosition(i).column.get(), wd.selective_offsets); } } + else + { + for (size_t i = 0; i < left_columns; ++i) + { + wd.result_block.safeGetByPosition(i).column->assumeMutable()->insertSelectiveFrom( + *context.block.safeGetByPosition(i).column.get(), + wd.selective_offsets); + } + } + wd.replicate_time += watch.elapsedFromLastTime(); if (has_other_condition) { - auto res_block = handleOtherConditions(stream_index); + auto res_block = handleOtherConditions(context, wd, late_materialization); wd.other_condition_time += watch.elapsedFromLastTime(); return res_block; } @@ -684,6 +776,34 @@ void HashJoin::finalize(const Names & parent_require) output_columns_names_set_for_other_condition_after_finalize.insert(name); if (!match_helper_name.empty()) output_columns_names_set_for_other_condition_after_finalize.insert(match_helper_name); + + if (non_equal_conditions.other_cond_expr != nullptr) + { + const auto & actions = non_equal_conditions.other_cond_expr->getActions(); + for (const auto & action : actions) + { + Names needed_columns = 
action.getNeededColumns(); + for (const auto & name : needed_columns) + { + if (output_columns_names_set_for_other_condition_after_finalize.contains(name)) + required_columns_names_set_for_other_condition.insert(name); + } + } + } + + if (non_equal_conditions.null_aware_eq_cond_expr != nullptr) + { + const auto & actions = non_equal_conditions.null_aware_eq_cond_expr->getActions(); + for (const auto & action : actions) + { + Names needed_columns = action.getNeededColumns(); + for (const auto & name : needed_columns) + { + if (output_columns_names_set_for_other_condition_after_finalize.contains(name)) + required_columns_names_set_for_other_condition.insert(name); + } + } + } } /// remove duplicated column @@ -706,28 +826,50 @@ void HashJoin::finalize(const Names & parent_require) finalized = true; } -Block HashJoin::handleOtherConditions(size_t stream_index) +Block HashJoin::handleOtherConditions(JoinProbeContext & context, JoinProbeWorkerData & wd, bool late_materialization) { - auto & wd = probe_workers_data[stream_index]; - non_equal_conditions.other_cond_expr->execute(wd.result_block); + size_t left_columns = left_sample_block_pruned.columns(); + size_t right_columns = right_sample_block_pruned.columns(); + // Some columns in wd.result_block may be empty so need to create another block to execute other condition expressions + Block exec_block; + for (size_t i = 0; i < left_columns; ++i) + { + if (left_required_flag_for_other_condition[i]) + exec_block.insert(wd.result_block.getByPosition(i)); + } + if (late_materialization) + { + for (auto [column_index, _] : row_layout.raw_required_key_column_indexes) + exec_block.insert(wd.result_block.getByPosition(left_columns + column_index)); + for (size_t i = 0; i < row_layout.other_required_count_for_other_condition; ++i) + { + size_t column_index = row_layout.other_required_column_indexes[i].first; + exec_block.insert(wd.result_block.getByPosition(left_columns + column_index)); + } + } + else + { + for (size_t i = 0; i 
< right_columns; ++i) + exec_block.insert(wd.result_block.getByPosition(left_columns + i)); + } + + non_equal_conditions.other_cond_expr->execute(exec_block); - size_t rows = wd.result_block.rows(); + size_t rows = exec_block.rows(); RUNTIME_CHECK_MSG( rows <= settings.max_block_size, - "result_block rows {} > max_block_size {}", + "exec_block rows {} > max_block_size {}", rows, settings.max_block_size); wd.filter.clear(); - wd.filter.reserve(rows); - mergeNullAndFilterResult(wd.result_block, wd.filter, non_equal_conditions.other_cond_name, false); - RUNTIME_CHECK(wd.filter.size() == rows); + mergeNullAndFilterResult(exec_block, wd.filter, non_equal_conditions.other_cond_name, false); - size_t columns = output_block_after_finalize.columns(); + size_t output_columns = output_block_after_finalize.columns(); auto init_result_block_for_other_condition = [&]() { wd.result_block_for_other_condition = {}; - for (size_t i = 0; i < columns; ++i) + for (size_t i = 0; i < output_columns; ++i) { ColumnWithTypeAndName new_column = output_block_after_finalize.safeGetByPosition(i).cloneEmpty(); new_column.column->assumeMutable()->reserveAlign(settings.max_block_size, FULL_VECTOR_SIZE_AVX2); @@ -744,54 +886,173 @@ Block HashJoin::handleOtherConditions(size_t stream_index) wd.result_block_for_other_condition.rows(), settings.max_block_size); size_t remaining_insert_size = settings.max_block_size - wd.result_block_for_other_condition.rows(); - size_t result_size = countBytesInFilter(wd.filter); - wd.filter_offsets1.clear(); - wd.filter_offsets1.reserve(result_size); - filterImpl(&wd.filter[0], &wd.filter[rows], &base_offsets[0], wd.filter_offsets1); - RUNTIME_CHECK(wd.filter_offsets1.size() == result_size); - if (result_size > remaining_insert_size) - { - wd.filter_offsets2.clear(); - wd.filter_offsets2.resize(result_size - remaining_insert_size); - memcpy( - &wd.filter_offsets2[0], - &wd.filter_offsets1[remaining_insert_size], - sizeof(IColumn::Offset) * (result_size - 
remaining_insert_size)); - wd.filter_offsets1.resize(remaining_insert_size); - } + bool filter_offsets_is_initialized = false; + auto init_filter_offsets = [&]() { + RUNTIME_CHECK(wd.filter.size() == rows); + wd.filter_offsets.clear(); + wd.filter_offsets.reserve(result_size); + filterImpl(&wd.filter[0], &wd.filter[rows], &base_offsets[0], wd.filter_offsets); + RUNTIME_CHECK(wd.filter_offsets.size() == result_size); + filter_offsets_is_initialized = true; + }; - for (size_t i = 0; i < columns; ++i) - { - auto & des_column = wd.result_block_for_other_condition.safeGetByPosition(i); - auto & src_column = wd.result_block.getByName(des_column.name); - des_column.column->assumeMutable()->insertSelectiveFrom(*src_column.column.get(), wd.filter_offsets1); - } + bool filter_selective_offsets_is_initialized = false; + auto init_filter_selective_offsets = [&]() { + RUNTIME_CHECK(wd.selective_offsets.size() == rows); + wd.filter_selective_offsets.clear(); + wd.filter_selective_offsets.reserve(result_size); + filterImpl(&wd.filter[0], &wd.filter[rows], &wd.selective_offsets[0], wd.filter_selective_offsets); + RUNTIME_CHECK(wd.filter_selective_offsets.size() == result_size); + filter_selective_offsets_is_initialized = true; + }; + + bool filter_row_ptrs_for_lm_is_initialized = false; + auto init_filter_row_ptrs_for_lm = [&]() { + RUNTIME_CHECK(wd.row_ptrs_for_lm.size() == rows); + wd.filter_row_ptrs_for_lm.clear(); + wd.filter_row_ptrs_for_lm.reserve(result_size); + filterImpl(&wd.filter[0], &wd.filter[rows], &wd.row_ptrs_for_lm[0], wd.filter_row_ptrs_for_lm); + RUNTIME_CHECK(wd.filter_row_ptrs_for_lm.size() == result_size); + filter_row_ptrs_for_lm_is_initialized = true; + }; + + auto fill_block = [&](size_t start, size_t length) { + if (late_materialization) + { + for (auto [column_index, _] : row_layout.raw_required_key_column_indexes) + { + const auto & name = right_sample_block_pruned.getByPosition(column_index).name; + if (!wd.result_block_for_other_condition.has(name)) + 
continue; + if unlikely (!filter_offsets_is_initialized) + init_filter_offsets(); + auto & des_column = wd.result_block_for_other_condition.getByName(name); + auto & src_column = exec_block.getByName(name); + des_column.column->assumeMutable() + ->insertSelectiveRangeFrom(*src_column.column.get(), wd.filter_offsets, start, length); + } + for (size_t i = 0; i < row_layout.other_required_count_for_other_condition; ++i) + { + size_t column_index = row_layout.other_required_column_indexes[i].first; + const auto & name = right_sample_block_pruned.getByPosition(column_index).name; + if (!wd.result_block_for_other_condition.has(name)) + continue; + if unlikely (!filter_offsets_is_initialized) + init_filter_offsets(); + auto & des_column = wd.result_block_for_other_condition.getByName(name); + auto & src_column = exec_block.getByName(name); + des_column.column->assumeMutable() + ->insertSelectiveRangeFrom(*src_column.column.get(), wd.filter_offsets, start, length); + } + if (!filter_row_ptrs_for_lm_is_initialized) + init_filter_row_ptrs_for_lm(); + + std::vector actual_column_indexes; + for (size_t i = row_layout.other_required_count_for_other_condition; + i < row_layout.other_required_column_indexes.size(); + ++i) + { + size_t column_index = row_layout.other_required_column_indexes[i].first; + const auto & name = right_sample_block_pruned.getByPosition(column_index).name; + size_t actual_column_index = wd.result_block_for_other_condition.getPositionByName(name); + actual_column_indexes.emplace_back(actual_column_index); + } - Block res_block = output_block_after_finalize; + constexpr size_t step = 256; + for (size_t i = start; i < start + length; i += step) + { + size_t end = i + step > start + length ? 
start + length : i + step; + wd.insert_batch.clear(); + wd.insert_batch.insert(&wd.row_ptrs_for_lm[i], &wd.row_ptrs_for_lm[end]); + for (auto column_index : actual_column_indexes) + { + auto & des_column = wd.result_block_for_other_condition.getByPosition(column_index); + des_column.column->assumeMutable()->deserializeAndInsertFromPos(wd.insert_batch, true); + } + } + for (auto column_index : actual_column_indexes) + { + auto & des_column = wd.result_block_for_other_condition.getByPosition(column_index); + des_column.column->assumeMutable()->flushNTAlignBuffer(); + } + } + else + { + for (size_t i = 0; i < right_columns; ++i) + { + const auto & name = right_sample_block_pruned.getByPosition(i).name; + if (!wd.result_block_for_other_condition.has(name)) + continue; + if unlikely (!filter_offsets_is_initialized) + init_filter_offsets(); + auto & des_column = wd.result_block_for_other_condition.getByName(name); + auto & src_column = exec_block.getByName(name); + des_column.column->assumeMutable() + ->insertSelectiveRangeFrom(*src_column.column.get(), wd.filter_offsets, start, length); + } + } + + for (size_t i = 0; i < left_columns; ++i) + { + const auto & name = left_sample_block_pruned.getByPosition(i).name; + if (!wd.result_block_for_other_condition.has(name)) + continue; + auto & des_column = wd.result_block_for_other_condition.getByName(name); + if (left_required_flag_for_other_condition[i]) + { + if unlikely (!filter_offsets_is_initialized && !filter_selective_offsets_is_initialized) + init_filter_selective_offsets(); + if (filter_offsets_is_initialized) + { + auto & src_column = exec_block.getByName(name); + des_column.column->assumeMutable() + ->insertSelectiveRangeFrom(*src_column.column.get(), wd.filter_offsets, start, length); + } + else + { + auto & src_column = context.block.safeGetByPosition(i); + des_column.column->assumeMutable()->insertSelectiveRangeFrom( + *src_column.column.get(), + wd.filter_selective_offsets, + start, + length); + } + continue; + 
} + if unlikely (!filter_selective_offsets_is_initialized) + init_filter_selective_offsets(); + auto & src_column = context.block.safeGetByPosition(i); + des_column.column->assumeMutable() + ->insertSelectiveRangeFrom(*src_column.column.get(), wd.filter_selective_offsets, start, length); + } + }; + + size_t length = result_size > remaining_insert_size ? remaining_insert_size : result_size; + fill_block(0, length); + + Block res_block; if (result_size >= remaining_insert_size) { res_block = wd.result_block_for_other_condition; init_result_block_for_other_condition(); if (result_size > remaining_insert_size) - { - for (size_t i = 0; i < columns; ++i) - { - auto & des_column = wd.result_block_for_other_condition.safeGetByPosition(i); - auto & src_column = wd.result_block.getByName(des_column.name); - des_column.column->assumeMutable()->insertSelectiveFrom(*src_column.column.get(), wd.filter_offsets2); - } - } + fill_block(remaining_insert_size, result_size - remaining_insert_size); + } + else + { + res_block = output_block_after_finalize; } + exec_block.clear(); /// Remove the new column added from other condition expressions. removeUselessColumn(wd.result_block); assertBlocksHaveEqualStructure( wd.result_block, all_sample_block_pruned, - "Join Probe reuse result_block for other condition"); + "Join Probe reuses result_block for other condition"); /// Clear the data in result_block. 
for (size_t i = 0; i < wd.result_block.columns(); ++i) diff --git a/dbms/src/Interpreters/JoinV2/HashJoin.h b/dbms/src/Interpreters/JoinV2/HashJoin.h index 8efbe6ff659..105e7d137b9 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoin.h +++ b/dbms/src/Interpreters/JoinV2/HashJoin.h @@ -78,7 +78,7 @@ class HashJoin void workAfterBuildRowFinish(); - Block handleOtherConditions(size_t stream_index); + Block handleOtherConditions(JoinProbeContext & context, JoinProbeWorkerData & wd, bool late_materialization); private: const ASTTableJoin::Kind kind; @@ -125,6 +125,7 @@ class HashJoin NameSet output_column_names_set_after_finalize; NameSet output_columns_names_set_for_other_condition_after_finalize; Names required_columns; + NameSet required_columns_names_set_for_other_condition; bool finalized = false; /// Row containers @@ -146,6 +147,7 @@ class HashJoin /// For other condition const IColumn::Offsets base_offsets; + BoolVec left_required_flag_for_other_condition; }; using HashJoinPtr = std::shared_ptr; diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index 06a7fa543ff..5f2f5277a45 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -13,12 +13,16 @@ // limitations under the License. 
#include +#include #include #include #include #include #include +#include "Columns/IColumn.h" +#include "Parsers/ASTTablesInSelectQuery.h" + #ifdef TIFLASH_ENABLE_AVX_SUPPORT ASSERT_USE_AVX2_COMPILE_FLAG #endif @@ -30,7 +34,7 @@ using enum ASTTableJoin::Kind; bool JoinProbeContext::isCurrentProbeFinished() const { - return start_row_idx >= rows && prefetch_active_states == 0; + return start_row_idx >= rows && prefetch_active_states == 0 && rows_is_matched.empty(); } void JoinProbeContext::resetBlock(Block & block_) @@ -39,7 +43,10 @@ void JoinProbeContext::resetBlock(Block & block_) orignal_block = block_; rows = block.rows(); start_row_idx = 0; - current_probe_row_ptr = nullptr; + current_row_ptr = nullptr; + current_row_is_matched = false; + rows_is_matched.clear(); + prefetch_active_states = 0; is_prepared = false; @@ -47,12 +54,12 @@ void JoinProbeContext::resetBlock(Block & block_) key_columns.clear(); null_map = nullptr; null_map_holder = nullptr; - current_row_is_matched = false; } void JoinProbeContext::prepareForHashProbe( HashJoinKeyMethod method, ASTTableJoin::Kind kind, + bool has_other_condition, const Names & key_names, const String & filter_column, const NameSet & probe_output_name_set, @@ -95,12 +102,18 @@ void JoinProbeContext::prepareForHashProbe( assertBlocksHaveEqualStructure(block, sample_block_pruned, "Join Probe"); + if ((kind == LeftOuter || kind == Semi || kind == Anti) && has_other_condition) + { + rows_is_matched.clear(); + rows_is_matched.resize_fill_zero(block.rows()); + } + is_prepared = true; } #define PREFETCH_READ(ptr) __builtin_prefetch((ptr), 0 /* rw==read */, 3 /* locality */) -/// The implemtation of prefetch in join probe process is inspired by a paper named +/// The implemtation of prefetching in join probe process is inspired by a paper named /// `Asynchronous Memory Access Chaining` in vldb-15. 
/// Ref: https://www.vldb.org/pvldb/vol9/p252-kocberber.pdf enum class ProbePrefetchStage : UInt8 @@ -130,8 +143,7 @@ struct ProbePrefetchState KeyType key{}; }; - -template +template class JoinProbeBlockHelper { public: @@ -164,9 +176,13 @@ class JoinProbeBlockHelper { wd.insert_batch.clear(); wd.insert_batch.reserve(settings.probe_insert_batch_size); - wd.selective_offsets.clear(); wd.selective_offsets.reserve(settings.max_block_size); + if constexpr (late_materialization) + { + wd.row_ptrs_for_lm.clear(); + wd.row_ptrs_for_lm.reserve(settings.max_block_size); + } if (pointer_table.enableProbePrefetch() && !context.prefetch_states) { @@ -174,10 +190,25 @@ class JoinProbeBlockHelper static_cast(new ProbePrefetchState[settings.probe_prefetch_step]), [](void * ptr) { delete[] static_cast *>(ptr); }); } + + if constexpr (late_materialization) + { + RUNTIME_CHECK( + added_columns.size() + == row_layout.raw_required_key_column_indexes.size() + + row_layout.other_required_count_for_other_condition); + } + else + { + RUNTIME_CHECK( + added_columns.size() + == row_layout.raw_required_key_column_indexes.size() + row_layout.other_required_column_indexes.size()); + } } void joinProbeBlockImpl(); +private: void NO_INLINE joinProbeBlockInner(); void NO_INLINE joinProbeBlockInnerPrefetch(); @@ -190,22 +221,15 @@ class JoinProbeBlockHelper void NO_INLINE joinProbeBlockAnti(); void NO_INLINE joinProbeBlockAntiPrefetch(); - template void NO_INLINE joinProbeBlockRightOuter(); - template void NO_INLINE joinProbeBlockRightOuterPrefetch(); - template void NO_INLINE joinProbeBlockRightSemi(); - template void NO_INLINE joinProbeBlockRightSemiPrefetch(); - template void NO_INLINE joinProbeBlockRightAnti(); - template void NO_INLINE joinProbeBlockRightAntiPrefetch(); -private: bool ALWAYS_INLINE joinKeyIsEqual( KeyGetterType & key_getter, const KeyType & key1, @@ -236,33 +260,69 @@ class JoinProbeBlockHelper if likely (wd.insert_batch.size() < settings.probe_insert_batch_size) return; } - 
for (auto [column_index, is_nullable] : row_layout.raw_required_key_column_indexes) + if constexpr (late_materialization) { - IColumn * column = added_columns[column_index].get(); - if (has_null_map && is_nullable) - column = &static_cast(*added_columns[column_index]).getNestedColumn(); - column->deserializeAndInsertFromPos(wd.insert_batch, true); - } - for (auto [column_index, _] : row_layout.other_required_column_indexes) - added_columns[column_index]->deserializeAndInsertFromPos(wd.insert_batch, true); + size_t idx = 0; + for (auto [_, is_nullable] : row_layout.raw_required_key_column_indexes) + { + IColumn * column = added_columns[idx].get(); + if (has_null_map && is_nullable) + column = &static_cast(*added_columns[idx]).getNestedColumn(); + column->deserializeAndInsertFromPos(wd.insert_batch, true); + ++idx; + } + for (size_t i = 0; i < row_layout.other_required_count_for_other_condition; ++i) + added_columns[idx++]->deserializeAndInsertFromPos(wd.insert_batch, true); - if constexpr (force) + wd.row_ptrs_for_lm.insert(wd.insert_batch.begin(), wd.insert_batch.end()); + } + else { for (auto [column_index, is_nullable] : row_layout.raw_required_key_column_indexes) { IColumn * column = added_columns[column_index].get(); if (has_null_map && is_nullable) column = &static_cast(*added_columns[column_index]).getNestedColumn(); - column->flushNTAlignBuffer(); + column->deserializeAndInsertFromPos(wd.insert_batch, true); } for (auto [column_index, _] : row_layout.other_required_column_indexes) - added_columns[column_index]->flushNTAlignBuffer(); + added_columns[column_index]->deserializeAndInsertFromPos(wd.insert_batch, true); + } + + if constexpr (force) + { + if constexpr (late_materialization) + { + size_t idx = 0; + for (auto [_, is_nullable] : row_layout.raw_required_key_column_indexes) + { + IColumn * column = added_columns[idx].get(); + if (has_null_map && is_nullable) + column = &static_cast(*added_columns[idx]).getNestedColumn(); + column->flushNTAlignBuffer(); 
+ ++idx; + } + for (size_t i = 0; i < row_layout.other_required_count_for_other_condition; ++i) + added_columns[idx++]->flushNTAlignBuffer(); + } + else + { + for (auto [column_index, is_nullable] : row_layout.raw_required_key_column_indexes) + { + IColumn * column = added_columns[column_index].get(); + if (has_null_map && is_nullable) + column = &static_cast(*added_columns[column_index]).getNestedColumn(); + column->flushNTAlignBuffer(); + } + for (auto [column_index, _] : row_layout.other_required_column_indexes) + added_columns[column_index]->flushNTAlignBuffer(); + } } wd.insert_batch.clear(); } - void ALWAYS_INLINE fillNullMap(size_t size) const + void ALWAYS_INLINE fillNullMapWithZero(size_t size) const { if constexpr (has_null_map) { @@ -288,17 +348,50 @@ class JoinProbeBlockHelper const HashJoinPointerTable & pointer_table; const HashJoinRowLayout & row_layout; MutableColumns & added_columns; - const size_t added_rows; + size_t added_rows; }; -template -void NO_INLINE JoinProbeBlockHelper::joinProbeBlockInner() +template +void JoinProbeBlockHelper::joinProbeBlockImpl() +{ +#define CALL(JoinType) \ + { \ + if (pointer_table.enableProbePrefetch()) \ + joinProbeBlock##JoinType##Prefetch(); \ + else \ + joinProbeBlock##JoinType(); \ + } + + if (kind == Inner) + CALL(Inner) + else if (kind == LeftOuter) + CALL(LeftOuter) + else if (kind == Semi) + CALL(Semi) + else if (kind == Anti) + CALL(Anti) + else if (kind == RightOuter) + CALL(RightOuter) + else if (kind == RightSemi) + CALL(RightSemi) + else if (kind == RightAnti) + CALL(RightAnti) + else + throw Exception("Logical error: unknown combination of JOIN", ErrorCodes::LOGICAL_ERROR); + +#undef CALL2 +#undef CALL +} + +template +void NO_INLINE +JoinProbeBlockHelper::joinProbeBlockInner() { auto & key_getter = *static_cast(context.key_getter.get()); size_t current_offset = added_rows; auto & selective_offsets = wd.selective_offsets; size_t idx = context.start_row_idx; - RowPtr ptr = 
context.current_probe_row_ptr; + RowPtr ptr = context.current_row_ptr; size_t collision = 0; size_t key_offset = sizeof(RowPtr); if constexpr (KeyGetterType::joinKeyCompareHashFirst()) @@ -356,15 +449,16 @@ void NO_INLINE JoinProbeBlockHelper::jo } } flushBatchIfNecessary(); - fillNullMap(current_offset - added_rows); + fillNullMapWithZero(current_offset - added_rows); context.start_row_idx = idx; - context.current_probe_row_ptr = ptr; + context.current_row_ptr = ptr; wd.collision += collision; } -template -void NO_INLINE JoinProbeBlockHelper::joinProbeBlockInnerPrefetch() +template +void NO_INLINE +JoinProbeBlockHelper::joinProbeBlockInnerPrefetch() { auto & key_getter = *static_cast(context.key_getter.get()); auto * states = static_cast *>(context.prefetch_states.get()); @@ -487,7 +581,7 @@ void NO_INLINE JoinProbeBlockHelper::jo } flushBatchIfNecessary(); - fillNullMap(current_offset - added_rows); + fillNullMapWithZero(current_offset - added_rows); context.start_row_idx = idx; context.prefetch_active_states = active_states; @@ -495,108 +589,70 @@ void NO_INLINE JoinProbeBlockHelper::jo wd.collision += collision; } -template -void NO_INLINE JoinProbeBlockHelper::joinProbeBlockLeftOuter() +template +void NO_INLINE +JoinProbeBlockHelper::joinProbeBlockLeftOuter() {} -template -void NO_INLINE JoinProbeBlockHelper::joinProbeBlockLeftOuterPrefetch() +template +void NO_INLINE +JoinProbeBlockHelper::joinProbeBlockLeftOuterPrefetch() {} -template -void NO_INLINE JoinProbeBlockHelper::joinProbeBlockSemi() +template +void NO_INLINE JoinProbeBlockHelper::joinProbeBlockSemi() {} -template -void NO_INLINE JoinProbeBlockHelper::joinProbeBlockSemiPrefetch() +template +void NO_INLINE +JoinProbeBlockHelper::joinProbeBlockSemiPrefetch() {} -template -void NO_INLINE JoinProbeBlockHelper::joinProbeBlockAnti() +template +void NO_INLINE JoinProbeBlockHelper::joinProbeBlockAnti() {} -template -void NO_INLINE JoinProbeBlockHelper::joinProbeBlockAntiPrefetch() +template +void NO_INLINE 
+JoinProbeBlockHelper::joinProbeBlockAntiPrefetch() {} -template -template -void NO_INLINE JoinProbeBlockHelper::joinProbeBlockRightOuter() +template +void NO_INLINE +JoinProbeBlockHelper::joinProbeBlockRightOuter() {} -template -template -void NO_INLINE JoinProbeBlockHelper::joinProbeBlockRightOuterPrefetch() +template +void NO_INLINE +JoinProbeBlockHelper::joinProbeBlockRightOuterPrefetch() {} -template -template -void NO_INLINE JoinProbeBlockHelper::joinProbeBlockRightSemi() +template +void NO_INLINE +JoinProbeBlockHelper::joinProbeBlockRightSemi() {} -template -template -void NO_INLINE JoinProbeBlockHelper::joinProbeBlockRightSemiPrefetch() +template +void NO_INLINE +JoinProbeBlockHelper::joinProbeBlockRightSemiPrefetch() {} -template -template -void NO_INLINE JoinProbeBlockHelper::joinProbeBlockRightAnti() +template +void NO_INLINE +JoinProbeBlockHelper::joinProbeBlockRightAnti() {} -template -template -void NO_INLINE JoinProbeBlockHelper::joinProbeBlockRightAntiPrefetch() +template +void NO_INLINE +JoinProbeBlockHelper::joinProbeBlockRightAntiPrefetch() {} -template -void JoinProbeBlockHelper::joinProbeBlockImpl() -{ -#define CALL(JoinType) \ - if (pointer_table.enableProbePrefetch()) \ - joinProbeBlock##JoinType##Prefetch(); \ - else \ - joinProbeBlock##JoinType(); - -#define CALL2(JoinType, has_other_condition) \ - if (pointer_table.enableProbePrefetch()) \ - joinProbeBlock##JoinType##Prefetch(); \ - else \ - joinProbeBlock##JoinType(); - - bool has_other_condition = non_equal_conditions.other_cond_expr != nullptr; - if (kind == Inner) - CALL(Inner) - else if (kind == LeftOuter) - CALL(LeftOuter) - else if (kind == Semi && !has_other_condition) - CALL(Semi) - else if (kind == Anti && !has_other_condition) - CALL(Anti) - else if (kind == RightOuter && has_other_condition) - CALL2(RightOuter, true) - else if (kind == RightOuter) - CALL2(RightOuter, false) - else if (kind == RightSemi && has_other_condition) - CALL2(RightSemi, true) - else if (kind == 
RightSemi) - CALL2(RightSemi, false) - else if (kind == RightAnti && has_other_condition) - CALL2(RightAnti, true) - else if (kind == RightAnti) - CALL2(RightAnti, false) - else - throw Exception("Logical error: unknown combination of JOIN", ErrorCodes::LOGICAL_ERROR); - -#undef CALL2 -#undef CALL -} - void joinProbeBlock( JoinProbeContext & context, JoinProbeWorkerData & wd, HashJoinKeyMethod method, ASTTableJoin::Kind kind, + bool late_materialization, const JoinNonEqualConditions & non_equal_conditions, const HashJoinSettings & settings, const HashJoinPointerTable & pointer_table, @@ -609,28 +665,38 @@ void joinProbeBlock( switch (method) { -#define CALL(KeyGetter, has_null_map, tagged_pointer) \ - JoinProbeBlockHelper( \ - context, \ - wd, \ - method, \ - kind, \ - non_equal_conditions, \ - settings, \ - pointer_table, \ - row_layout, \ - added_columns, \ - added_rows) \ +#define CALL(KeyGetter, has_null_map, tagged_pointer, late_materialization) \ + JoinProbeBlockHelper( \ + context, \ + wd, \ + method, \ + kind, \ + non_equal_conditions, \ + settings, \ + pointer_table, \ + row_layout, \ + added_columns, \ + added_rows) \ .joinProbeBlockImpl(); -#define CALL2(KeyGetter, has_null_map) \ - if (pointer_table.enableTaggedPointer()) \ - { \ - CALL(KeyGetter, has_null_map, true); \ - } \ - else \ - { \ - CALL(KeyGetter, has_null_map, false); \ +#define CALL3(KeyGetter, has_null_map, tagged_pointer) \ + if (late_materialization) \ + { \ + CALL(KeyGetter, has_null_map, tagged_pointer, true); \ + } \ + else \ + { \ + CALL(KeyGetter, has_null_map, tagged_pointer, false); \ + } + +#define CALL2(KeyGetter, has_null_map) \ + if (pointer_table.enableTaggedPointer()) \ + { \ + CALL3(KeyGetter, has_null_map, true); \ + } \ + else \ + { \ + CALL3(KeyGetter, has_null_map, false); \ } #define CALL1(KeyGetter) \ diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h index fa9ac0be5ea..22e8a7c0296 100644 --- 
a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h @@ -16,7 +16,9 @@ #include #include +#include #include +#include #include #include #include @@ -35,7 +37,11 @@ struct JoinProbeContext Block orignal_block; size_t rows = 0; size_t start_row_idx = 0; - RowPtr current_probe_row_ptr = nullptr; + RowPtr current_row_ptr = nullptr; + /// For left outer/(left outer) (anti) semi join without other conditions. + bool current_row_is_matched = false; + /// For left outer/(left outer) (anti) semi join with other conditions. + IColumn::Filter rows_is_matched; size_t prefetch_active_states = 0; size_t prefetch_iter = 0; @@ -48,8 +54,6 @@ struct JoinProbeContext ConstNullMapPtr null_map = nullptr; std::unique_ptr> key_getter; - bool current_row_is_matched = false; - bool input_is_finished = false; bool isCurrentProbeFinished() const; @@ -58,6 +62,7 @@ struct JoinProbeContext void prepareForHashProbe( HashJoinKeyMethod method, ASTTableJoin::Kind kind, + bool has_other_condition, const Names & key_names, const String & filter_column, const NameSet & probe_output_name_set, @@ -69,6 +74,7 @@ struct JoinProbeContext struct alignas(CPU_CACHE_LINE_SIZE) JoinProbeWorkerData { IColumn::Offsets selective_offsets; + RowPtrs row_ptrs_for_lm; RowPtrs insert_batch; @@ -79,10 +85,11 @@ struct alignas(CPU_CACHE_LINE_SIZE) JoinProbeWorkerData size_t other_condition_time = 0; size_t collision = 0; - /// filter for other condition + /// For other condition ColumnVector::Container filter; - IColumn::Offsets filter_offsets1; - IColumn::Offsets filter_offsets2; + IColumn::Offsets filter_offsets; + IColumn::Offsets filter_selective_offsets; + RowPtrs filter_row_ptrs_for_lm; /// Schema: HashJoin::all_sample_block_pruned Block result_block; @@ -95,6 +102,7 @@ void joinProbeBlock( JoinProbeWorkerData & wd, HashJoinKeyMethod method, ASTTableJoin::Kind kind, + bool late_materialization, const JoinNonEqualConditions & non_equal_conditions, const HashJoinSettings 
& settings, const HashJoinPointerTable & pointer_table, diff --git a/dbms/src/Interpreters/JoinV2/HashJoinRowLayout.h b/dbms/src/Interpreters/JoinV2/HashJoinRowLayout.h index a82dac38533..72c88c1aaff 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinRowLayout.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinRowLayout.h @@ -84,6 +84,10 @@ struct HashJoinRowLayout std::vector> raw_required_key_column_indexes; /// other_required_column_index + is_fixed_size std::vector> other_required_column_indexes; + /// Number of columns at the beginning of `other_required_column_indexes` + /// that are used for evaluating the join other condition. + size_t other_required_count_for_other_condition = 0; + size_t key_column_fixed_size = 0; size_t other_column_fixed_size = 0; diff --git a/dbms/src/Interpreters/NullAwareSemiJoinHelper.cpp b/dbms/src/Interpreters/NullAwareSemiJoinHelper.cpp index c7c43051fda..7bec95bd4e9 100644 --- a/dbms/src/Interpreters/NullAwareSemiJoinHelper.cpp +++ b/dbms/src/Interpreters/NullAwareSemiJoinHelper.cpp @@ -19,6 +19,8 @@ #include #include "Functions/FunctionBinaryArithmetic.h" +#include "Interpreters/SemiJoinHelper.h" +#include "Parsers/ASTTablesInSelectQuery.h" namespace DB { @@ -291,35 +293,17 @@ Block NASemiJoinHelper::genJoinResult(const NameSet & ou auto result = join_result[i].getResult(); if constexpr (KIND == ASTTableJoin::Kind::NullAware_Anti) { - if (result == SemiJoinResultType::TRUE_VALUE) - { - // If the result is true, this row should be kept. - (*filter)[i] = 1; - ++rows_for_anti; - } - else - { - // If the result is null or false, this row should be filtered. - (*filter)[i] = 0; - } + // If the result is true, this row should be kept. + // Otherwise, this row should be filtered. + (*filter)[i] = result == SemiJoinResultType::TRUE_VALUE ? 
1 : 0; + rows_for_anti += (*filter)[i]; } else { - switch (result) - { - case SemiJoinResultType::FALSE_VALUE: - left_semi_column_data->push_back(0); - left_semi_null_map->push_back(0); - break; - case SemiJoinResultType::TRUE_VALUE: - left_semi_column_data->push_back(1); - left_semi_null_map->push_back(0); - break; - case SemiJoinResultType::NULL_VALUE: - left_semi_column_data->push_back(0); - left_semi_null_map->push_back(1); - break; - } + Int8 res = result == SemiJoinResultType::TRUE_VALUE ? 1 : 0; + UInt8 is_null = result == SemiJoinResultType::NULL_VALUE ? 1 : 0; + left_semi_column_data->push_back(res); + left_semi_null_map->push_back(is_null); } } diff --git a/dbms/src/Interpreters/NullAwareSemiJoinHelper.h b/dbms/src/Interpreters/NullAwareSemiJoinHelper.h index 906ccc80b58..530a68df800 100644 --- a/dbms/src/Interpreters/NullAwareSemiJoinHelper.h +++ b/dbms/src/Interpreters/NullAwareSemiJoinHelper.h @@ -33,7 +33,7 @@ enum class NASemiJoinStep : UInt8 /// with at least one null join key. /// The join keys of this left row must not have null. NOT_NULL_KEY_CHECK_NULL_ROWS, - /// Like `CHECK_NULL_ROWS_NOT_NULL` except the join keys of this left row must have null. + /// Like `NOT_NULL_KEY_CHECK_NULL_ROWS` except the join keys of this left row must have null. NULL_KEY_CHECK_NULL_ROWS, /// Check join key equal condition and other conditions(if any) for all right rows in blocks. /// The join keys of this left row must have null. 
From cff7cccd926b7906fb5a69d90bf72367d00b6053 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Sat, 22 Mar 2025 14:36:34 +0800 Subject: [PATCH 02/84] u Signed-off-by: gengliqi --- .gitmodules | 3 - cmake/sanitize.cmake | 2 +- .../0003-Suppress-enum-overflow.patch | 25 + contrib/aws-cmake/CMakeLists.txt | 47 +- dbms/CMakeLists.txt | 6 - dbms/cmake/find_vectorclass.cmake | 29 - .../AggregateFunctionMinMaxAny.h | 8 + .../AggregateFunctionMinMaxWindow.h | 6 + dbms/src/Client/Connection.cpp | 71 +- dbms/src/Client/Connection.h | 10 +- dbms/src/Client/MultiplexedConnections.cpp | 29 +- dbms/src/Client/MultiplexedConnections.h | 6 +- dbms/src/Columns/ColumnDecimal.cpp | 6 +- dbms/src/Columns/ColumnString.cpp | 10 +- dbms/src/Columns/ColumnString.h | 13 +- dbms/src/Columns/FilterDescription.h | 2 +- dbms/src/Columns/IColumn.h | 6 +- dbms/src/Columns/VirtualColumnUtils.cpp | 6 +- dbms/src/Columns/VirtualColumnUtils.h | 2 +- dbms/src/Columns/filterColumn.cpp | 122 +- .../src/Columns/tests/bench_column_filter.cpp | 93 +- .../tests/bench_column_string_filter.cpp | 261 ++++ .../gtest_column_serialize_deserialize.cpp | 14 +- dbms/src/Common/Arena.h | 26 +- dbms/src/Common/ColumnsHashing.h | 430 +++--- dbms/src/Common/ColumnsHashingImpl.h | 12 + dbms/src/Common/ExternalTable.h | 250 ---- dbms/src/Common/FailPoint.cpp | 7 +- dbms/src/Common/TiFlashMetrics.h | 16 + dbms/src/Common/config.h.in | 1 - dbms/src/Common/config_build.cpp.in | 2 - dbms/src/Core/Block.cpp | 2 + .../AddingDefaultBlockOutputStream.cpp | 25 +- .../AddingDefaultBlockOutputStream.h | 14 +- .../src/DataStreams/FilterTransformAction.cpp | 10 +- .../MergeSortingBlockInputStream.cpp | 5 + .../PushingToViewsBlockOutputStream.cpp | 61 - .../PushingToViewsBlockOutputStream.h | 71 - .../src/DataStreams/WindowTransformAction.cpp | 6 +- dbms/src/DataStreams/WindowTransformAction.h | 6 +- .../src/Debug/MockKVStore/MockProxyRegion.cpp | 4 +- .../Debug/MockKVStore/MockRaftStoreProxy.cpp | 14 +- 
.../Debug/MockKVStore/MockRaftStoreProxy.h | 1 + dbms/src/Debug/MockKVStore/MockReadIndex.cpp | 2 +- dbms/src/Debug/MockKVStore/MockSSTReader.h | 3 +- dbms/src/Debug/MockKVStore/MockTiKV.cpp | 69 + dbms/src/Debug/MockKVStore/MockTiKV.h | 32 +- dbms/src/Debug/MockKVStore/MockUtils.cpp | 897 +++++++++++++ dbms/src/Debug/MockKVStore/MockUtils.h | 237 ++-- dbms/src/Debug/ReadIndexStressTest.cpp | 1 + .../dbgKVStore/dbgFuncMockRaftCommand.cpp | 8 +- .../dbgKVStore/dbgFuncMockRaftSnapshot.cpp | 37 +- dbms/src/Debug/dbgKVStore/dbgFuncRegion.cpp | 5 +- dbms/src/Debug/dbgKVStore/dbgKVStore.h | 3 +- dbms/src/Debug/dbgKVStore/dbgRegion.h | 4 +- dbms/src/Debug/dbgNaturalDag.cpp | 19 +- dbms/src/Debug/dbgNaturalDag.h | 1 - dbms/src/Debug/dbgQueryExecutor.cpp | 1 + dbms/src/Debug/dbgTools.cpp | 761 +---------- dbms/src/Debug/dbgTools.h | 96 +- .../ComplexKeyCacheDictionary.cpp | 393 ------ .../Dictionaries/ComplexKeyCacheDictionary.h | 689 ---------- ...acheDictionary_createAttributeWithType.cpp | 95 -- .../ComplexKeyCacheDictionary_generate1.cpp | 55 - .../ComplexKeyCacheDictionary_generate2.cpp | 54 - .../ComplexKeyCacheDictionary_generate3.cpp | 54 - ...exKeyCacheDictionary_setAttributeValue.cpp | 82 -- ...cheDictionary_setDefaultAttributeValue.cpp | 75 -- .../Flash/Coprocessor/TablesRegionsInfo.cpp | 1 + .../Mpp/tests/gtest_mpp_task_manager.cpp | 3 +- .../Flash/Planner/Plans/PhysicalJoinV2.cpp | 2 +- .../Planner/tests/gtest_physical_plan.cpp | 1 - .../tests/gtest_aggregation_executor.cpp | 23 +- dbms/src/Flash/tests/gtest_spill_sort.cpp | 55 + dbms/src/Functions/CMakeLists.txt | 4 - dbms/src/Functions/FunctionsMath.h | 73 +- dbms/src/Interpreters/Aggregator.cpp | 826 ++++++------ dbms/src/Interpreters/Aggregator.h | 595 ++------- dbms/src/Interpreters/Context.cpp | 165 +-- dbms/src/Interpreters/Context.h | 58 +- dbms/src/Interpreters/DictionaryFactory.cpp | 30 +- dbms/src/Interpreters/ExpressionAnalyzer.cpp | 276 +--- dbms/src/Interpreters/ExpressionAnalyzer.h | 30 +- 
.../src/Interpreters/InterpreterDropQuery.cpp | 30 +- .../Interpreters/InterpreterExistsQuery.cpp | 3 +- .../Interpreters/InterpreterInsertQuery.cpp | 17 +- .../Interpreters/InterpreterSelectQuery.cpp | 32 +- dbms/src/Interpreters/JoinV2/HashJoin.cpp | 489 ++----- dbms/src/Interpreters/JoinV2/HashJoin.h | 20 +- .../src/Interpreters/JoinV2/HashJoinBuild.cpp | 51 +- dbms/src/Interpreters/JoinV2/HashJoinBuild.h | 8 +- dbms/src/Interpreters/JoinV2/HashJoinKey.cpp | 2 +- .../JoinV2/HashJoinPointerTable.cpp | 2 +- .../src/Interpreters/JoinV2/HashJoinProbe.cpp | 1164 +++++++++++------ dbms/src/Interpreters/JoinV2/HashJoinProbe.h | 226 +++- .../Interpreters/JoinV2/HashJoinRowLayout.h | 118 +- .../Interpreters/JoinV2/HashJoinSettings.h | 14 +- .../LogicalExpressionsOptimizer.cpp | 36 +- .../LogicalExpressionsOptimizer.h | 2 +- .../Interpreters/NullAwareSemiJoinHelper.cpp | 5 +- dbms/src/Interpreters/SemiJoinHelper.cpp | 21 +- .../tests/gtest_interpreter_create_query.cpp | 10 +- .../Operators/HashJoinV2ProbeTransformOp.cpp | 6 +- dbms/src/Operators/MergeSortTransformOp.cpp | 5 + dbms/src/Parsers/ASTInsertQuery.h | 5 +- dbms/src/Parsers/ASTSelectQuery.cpp | 8 - dbms/src/Parsers/ASTSelectQuery.h | 3 +- dbms/src/Parsers/ExpressionElementParsers.cpp | 6 +- dbms/src/Parsers/ParserSelectQuery.cpp | 10 - dbms/src/Server/Client.cpp | 117 +- dbms/src/Server/DTTool/DTTool.cpp | 2 +- dbms/src/Server/DTTool/DTToolBench.cpp | 2 +- dbms/src/Server/DTTool/DTToolInspect.cpp | 2 +- dbms/src/Server/DTTool/DTToolMigrate.cpp | 2 +- dbms/src/Server/Server.cpp | 2 +- dbms/src/Server/TCPHandler.cpp | 19 +- .../DeltaMerge/BitmapFilter/BitmapFilter.h | 3 +- dbms/src/Storages/DeltaMerge/CMakeLists.txt | 2 + .../ColumnFile/ColumnFileInputStream.cpp | 70 + .../ColumnFile/ColumnFileInputStream.h | 76 ++ .../ColumnFile/ColumnFileInputStream_fwd.h | 18 +- .../ColumnFile/ColumnFileSetReader.h | 2 +- ...olumnFileSetWithVectorIndexInputStream.cpp | 234 ---- .../ColumnFileSetWithVectorIndexInputStream.h | 108 
-- .../DeltaMerge/ColumnFile/ColumnFileTiny.cpp | 17 + .../DeltaMerge/ColumnFile/ColumnFileTiny.h | 57 +- .../ColumnFileTinyLocalIndexWriter.cpp | 84 +- .../ColumnFileTinyVectorIndexReader.cpp | 142 -- .../ColumnFileTinyVectorIndexReader.h | 94 -- .../ConcatSkippableBlockInputStream.cpp | 128 +- .../ConcatSkippableBlockInputStream.h | 66 +- .../ConcatSkippableBlockInputStream_fwd.h} | 22 +- .../Decode/SSTFilesToBlockInputStream.h | 3 +- .../Decode/SSTFilesToDTFilesOutputStream.h | 3 +- .../DeltaMerge/Delta/ColumnFilePersistedSet.h | 2 +- .../DeltaMerge/Delta/DeltaValueSpace.cpp | 17 +- .../Storages/DeltaMerge/DeltaMergeStore.cpp | 46 +- .../src/Storages/DeltaMerge/DeltaMergeStore.h | 2 + .../DeltaMergeStore_InternalSegment.cpp | 37 +- .../Storages/DeltaMerge/File/ColumnCache.cpp | 90 +- .../Storages/DeltaMerge/File/ColumnCache.h | 21 +- dbms/src/Storages/DeltaMerge/File/DMFile.h | 7 +- .../File/DMFileBlockInputStream.cpp | 77 +- .../DeltaMerge/File/DMFileBlockInputStream.h | 31 +- .../File/DMFileLocalIndexWriter.cpp | 66 +- .../DeltaMerge/File/DMFileLocalIndexWriter.h | 2 + .../Storages/DeltaMerge/File/DMFileReader.cpp | 35 +- .../Storages/DeltaMerge/File/DMFileReader.h | 4 +- .../File/DMFileVectorIndexReader.cpp | 226 ---- .../DeltaMerge/File/DMFileVectorIndexReader.h | 90 -- .../DMFileWithVectorIndexBlockInputStream.cpp | 221 ---- .../DMFileWithVectorIndexBlockInputStream.h | 139 -- .../Index/InvertedIndex/CommonUtil.cpp | 189 +++ .../Index/InvertedIndex/CommonUtil.h | 99 ++ .../DeltaMerge/Index/InvertedIndex/Reader.cpp | 261 ++++ .../DeltaMerge/Index/InvertedIndex/Reader.h | 99 ++ .../Perf.cpp => InvertedIndex/Reader_fwd.h} | 13 +- .../DeltaMerge/Index/InvertedIndex/Writer.cpp | 280 ++++ .../DeltaMerge/Index/InvertedIndex/Writer.h | 105 ++ .../DeltaMerge/Index/LocalIndexInfo.cpp | 8 +- .../DeltaMerge/Index/LocalIndexInfo.h | 18 + .../DeltaMerge/Index/LocalIndexWriter.cpp | 74 ++ .../DeltaMerge/Index/LocalIndexWriter.h | 89 ++ 
.../DeltaMerge/Index/LocalIndexWriter_fwd.h | 31 + .../DeltaMerge/Index/VectorIndex/Perf.h | 57 +- .../VectorIndex/Perf_fwd.h} | 4 +- .../DeltaMerge/Index/VectorIndex/Reader.cpp | 98 +- .../DeltaMerge/Index/VectorIndex/Reader.h | 38 +- .../Stream/ColumnFileInputStream.cpp | 125 ++ .../Stream/ColumnFileInputStream.h | 73 ++ .../Stream/ColumnFileInputStream_fwd.h | 26 + .../Index/VectorIndex/Stream/Ctx.cpp | 106 ++ .../DeltaMerge/Index/VectorIndex/Stream/Ctx.h | 82 ++ .../Index/VectorIndex/Stream/Ctx_fwd.h | 28 + .../VectorIndex/Stream/DMFileInputStream.cpp | 212 +++ .../VectorIndex/Stream/DMFileInputStream.h | 88 ++ .../Stream/DMFileInputStream_fwd.h | 26 + .../VectorIndex/Stream/IProvideVectorIndex.h | 60 + .../Index/VectorIndex/Stream/InputStream.cpp | 230 ++++ .../Index/VectorIndex/Stream/InputStream.h | 78 ++ .../VectorIndex/Stream/InputStream_fwd.h | 26 + .../Stream/ReaderFromColumnFileTiny.cpp | 85 ++ .../Stream/ReaderFromColumnFileTiny.h | 31 + .../VectorIndex/Stream/ReaderFromDMFile.cpp | 150 +++ .../VectorIndex/Stream/ReaderFromDMFile.h | 30 + .../DeltaMerge/Index/VectorIndex/Writer.cpp | 59 +- .../DeltaMerge/Index/VectorIndex/Writer.h | 62 +- dbms/src/Storages/DeltaMerge/Segment.cpp | 154 +-- dbms/src/Storages/DeltaMerge/Segment.h | 3 +- .../Storages/DeltaMerge/SegmentReadTask.cpp | 2 +- .../DeltaMerge/SkippableBlockInputStream.h | 40 + .../Storages/DeltaMerge/StableValueSpace.cpp | 102 +- .../Storages/DeltaMerge/StableValueSpace.h | 24 +- .../DeltaMerge/VectorIndexBlockInputStream.h | 55 - .../DeltaMerge/VersionChain/ColumnView.h | 158 +++ .../Storages/DeltaMerge/dtpb/index_file.proto | 6 + .../src/Storages/DeltaMerge/tests/DMTestEnv.h | 47 +- .../DeltaMerge/tests/gtest_column_cache.cpp | 121 ++ .../tests/gtest_dm_delta_merge_store.cpp | 22 +- ...est_dm_delta_merge_store_fast_add_peer.cpp | 14 +- ...test_dm_delta_merge_store_vector_index.cpp | 273 ++-- .../tests/gtest_dm_vector_index.cpp | 845 +++++++----- .../tests/gtest_dm_vector_index_utils.h | 44 +- 
.../DeltaMerge/tests/gtest_inverted_index.cpp | 302 +++++ .../DeltaMerge/tests/gtest_segment_bitmap.cpp | 436 +++--- .../tests/gtest_segment_test_basic.cpp | 244 +++- .../tests/gtest_segment_test_basic.h | 46 +- .../DeltaMerge/tests/gtest_segment_util.cpp | 142 +- .../DeltaMerge/tests/gtest_segment_util.h | 35 +- .../gtest_skippable_block_input_stream.cpp | 23 +- .../tests/gtest_sst_files_stream.cpp | 9 +- .../Storages/DeltaMerge/workload/Options.cpp | 3 +- dbms/src/Storages/IManageableStorage.h | 2 +- dbms/src/Storages/IStorage.h | 24 +- dbms/src/Storages/KVStore/BackgroundService.h | 6 - .../KVStore/Decode/PartitionStreams.cpp | 76 +- .../KVStore/Decode/PartitionStreams.h | 3 +- .../Storages/KVStore/Decode/RegionTable.cpp | 12 - .../src/Storages/KVStore/Decode/RegionTable.h | 70 +- dbms/src/Storages/KVStore/FFI/ProxyFFI.cpp | 1 + dbms/src/Storages/KVStore/KVStore.h | 5 +- .../KVStore/MultiRaft/ApplySnapshot.cpp | 13 + .../KVStore/MultiRaft/ApplySnapshot.h | 65 + .../MultiRaft/Disagg/CheckpointIngestInfo.cpp | 1 + .../MultiRaft/Disagg/CheckpointIngestInfo.h | 3 +- .../KVStore/MultiRaft/Disagg/FastAddPeer.h | 4 +- .../MultiRaft/Disagg/FastAddPeerContext.h | 3 +- .../KVStore/MultiRaft/PrehandleSnapshot.cpp | 1 + .../Storages/KVStore/MultiRaft/RegionData.h | 1 - .../KVStore/MultiRaft/RegionExecutionResult.h | 5 +- .../Storages/KVStore/MultiRaft/RegionMeta.h | 4 +- .../KVStore/MultiRaft/RegionPersister.h | 4 +- .../KVStore/MultiRaft/RegionsRangeIndex.h | 5 +- .../Spill/RegionUncommittedDataList.h | 3 +- dbms/src/Storages/KVStore/ProxyStateMachine.h | 1 + dbms/src/Storages/KVStore/Read/LearnerRead.h | 1 - .../KVStore/Read/LearnerReadWorker.cpp | 8 +- dbms/src/Storages/KVStore/Region.cpp | 15 +- dbms/src/Storages/KVStore/Region.h | 9 +- dbms/src/Storages/KVStore/Region_fwd.h | 29 + dbms/src/Storages/KVStore/TMTContext.cpp | 1 + .../Storages/KVStore/tests/gtest_kvstore.cpp | 98 +- .../tests/gtest_kvstore_fast_add_peer.cpp | 3 +- .../KVStore/tests/gtest_learner_read.cpp | 
23 +- .../Storages/KVStore/tests/gtest_memory.cpp | 195 ++- .../KVStore/tests/gtest_new_kvstore.cpp | 58 +- .../tests/gtest_proxy_state_machine.cpp | 2 + .../tests/gtest_region_block_reader.cpp | 5 +- .../KVStore/tests/gtest_region_persister.cpp | 38 +- .../Storages/KVStore/tests/gtest_spill.cpp | 1 + .../KVStore/tests/gtest_sync_schema.cpp | 2 +- .../KVStore/tests/gtest_sync_status.cpp | 5 +- .../KVStore/tests/gtest_tikv_keyvalue.cpp | 2 +- .../Storages/KVStore/tests/kvstore_helper.h | 8 +- dbms/src/Storages/Page/V3/PageDirectory.cpp | 19 +- .../V3/Universal/UniversalPageIdFormatImpl.h | 8 +- .../Page/tools/PageCtl/PageStorageCtlV3.cpp | 2 +- .../Storages/Page/workload/PSStressEnv.cpp | 3 +- dbms/src/Storages/S3/FileCache.cpp | 4 + dbms/src/Storages/S3/FileCache.h | 10 +- .../src/Storages/S3/tests/gtest_filecache.cpp | 192 +-- dbms/src/Storages/SelectQueryInfo.h | 4 +- .../Storages/System/StorageSystemTables.cpp | 25 - dbms/src/TiDB/Schema/InvertedIndex.h | 59 + dbms/src/TiDB/Schema/TiDB.h | 1 + libs/libcommon/CMakeLists.txt | 1 - libs/libcommon/include/common/MultiVersion.h | 67 - .../include/common/iostream_debug_helpers.h | 210 --- libs/libcommon/src/tests/CMakeLists.txt | 17 - libs/libcommon/src/tests/date_lut2.cpp | 72 - libs/libcommon/src/tests/date_lut3.cpp | 88 -- libs/libcommon/src/tests/date_lut4.cpp | 35 - .../src/tests/date_lut_default_timezone.cpp | 46 - libs/libcommon/src/tests/multi_version.cpp | 71 - tests/README.md | 4 +- tests/docker/cluster.yaml | 6 +- tests/docker/cluster_tidb_fail_point.yaml | 6 +- tests/fullstack-test/expr/cast_as_json.test | 4 +- .../expr/duration_pushdown.test | 2 +- tests/fullstack-test/mpp/window_agg.test | 78 ++ tests/sanitize/asan.suppression | 2 + .../sanitize/asan_ignores.txt | 0 tests/sanitize/tsan.suppression | 5 + 283 files changed, 10185 insertions(+), 10097 deletions(-) create mode 100644 contrib/aws-cmake/0003-Suppress-enum-overflow.patch delete mode 100644 dbms/cmake/find_vectorclass.cmake create mode 100644 
dbms/src/Columns/tests/bench_column_string_filter.cpp delete mode 100644 dbms/src/Common/ExternalTable.h delete mode 100644 dbms/src/DataStreams/PushingToViewsBlockOutputStream.cpp delete mode 100644 dbms/src/DataStreams/PushingToViewsBlockOutputStream.h create mode 100644 dbms/src/Debug/MockKVStore/MockTiKV.cpp create mode 100644 dbms/src/Debug/MockKVStore/MockUtils.cpp delete mode 100644 dbms/src/Dictionaries/ComplexKeyCacheDictionary.cpp delete mode 100644 dbms/src/Dictionaries/ComplexKeyCacheDictionary.h delete mode 100644 dbms/src/Dictionaries/ComplexKeyCacheDictionary_createAttributeWithType.cpp delete mode 100644 dbms/src/Dictionaries/ComplexKeyCacheDictionary_generate1.cpp delete mode 100644 dbms/src/Dictionaries/ComplexKeyCacheDictionary_generate2.cpp delete mode 100644 dbms/src/Dictionaries/ComplexKeyCacheDictionary_generate3.cpp delete mode 100644 dbms/src/Dictionaries/ComplexKeyCacheDictionary_setAttributeValue.cpp delete mode 100644 dbms/src/Dictionaries/ComplexKeyCacheDictionary_setDefaultAttributeValue.cpp create mode 100644 dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileInputStream.cpp create mode 100644 dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileInputStream.h rename libs/libcommon/src/tests/date_lut_init.cpp => dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileInputStream_fwd.h (71%) delete mode 100644 dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileSetWithVectorIndexInputStream.cpp delete mode 100644 dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileSetWithVectorIndexInputStream.h delete mode 100644 dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileTinyVectorIndexReader.cpp delete mode 100644 dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileTinyVectorIndexReader.h rename dbms/src/Storages/{KVStore/tests/region_helper.h => DeltaMerge/ConcatSkippableBlockInputStream_fwd.h} (61%) delete mode 100644 dbms/src/Storages/DeltaMerge/File/DMFileVectorIndexReader.cpp delete mode 100644 dbms/src/Storages/DeltaMerge/File/DMFileVectorIndexReader.h 
delete mode 100644 dbms/src/Storages/DeltaMerge/File/DMFileWithVectorIndexBlockInputStream.cpp delete mode 100644 dbms/src/Storages/DeltaMerge/File/DMFileWithVectorIndexBlockInputStream.h create mode 100644 dbms/src/Storages/DeltaMerge/Index/InvertedIndex/CommonUtil.cpp create mode 100644 dbms/src/Storages/DeltaMerge/Index/InvertedIndex/CommonUtil.h create mode 100644 dbms/src/Storages/DeltaMerge/Index/InvertedIndex/Reader.cpp create mode 100644 dbms/src/Storages/DeltaMerge/Index/InvertedIndex/Reader.h rename dbms/src/Storages/DeltaMerge/Index/{VectorIndex/Perf.cpp => InvertedIndex/Reader_fwd.h} (73%) create mode 100644 dbms/src/Storages/DeltaMerge/Index/InvertedIndex/Writer.cpp create mode 100644 dbms/src/Storages/DeltaMerge/Index/InvertedIndex/Writer.h create mode 100644 dbms/src/Storages/DeltaMerge/Index/LocalIndexWriter.cpp create mode 100644 dbms/src/Storages/DeltaMerge/Index/LocalIndexWriter.h create mode 100644 dbms/src/Storages/DeltaMerge/Index/LocalIndexWriter_fwd.h rename dbms/src/Storages/DeltaMerge/{File/DMFileWithVectorIndexBlockInputStream_fwd.h => Index/VectorIndex/Perf_fwd.h} (81%) create mode 100644 dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/ColumnFileInputStream.cpp create mode 100644 dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/ColumnFileInputStream.h create mode 100644 dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/ColumnFileInputStream_fwd.h create mode 100644 dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/Ctx.cpp create mode 100644 dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/Ctx.h create mode 100644 dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/Ctx_fwd.h create mode 100644 dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/DMFileInputStream.cpp create mode 100644 dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/DMFileInputStream.h create mode 100644 dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/DMFileInputStream_fwd.h create mode 100644 
dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/IProvideVectorIndex.h create mode 100644 dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/InputStream.cpp create mode 100644 dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/InputStream.h create mode 100644 dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/InputStream_fwd.h create mode 100644 dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/ReaderFromColumnFileTiny.cpp create mode 100644 dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/ReaderFromColumnFileTiny.h create mode 100644 dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/ReaderFromDMFile.cpp create mode 100644 dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/ReaderFromDMFile.h delete mode 100644 dbms/src/Storages/DeltaMerge/VectorIndexBlockInputStream.h create mode 100644 dbms/src/Storages/DeltaMerge/VersionChain/ColumnView.h create mode 100644 dbms/src/Storages/DeltaMerge/tests/gtest_column_cache.cpp create mode 100644 dbms/src/Storages/DeltaMerge/tests/gtest_inverted_index.cpp create mode 100644 dbms/src/Storages/KVStore/MultiRaft/ApplySnapshot.h create mode 100644 dbms/src/Storages/KVStore/Region_fwd.h create mode 100644 dbms/src/TiDB/Schema/InvertedIndex.h delete mode 100644 libs/libcommon/include/common/MultiVersion.h delete mode 100644 libs/libcommon/include/common/iostream_debug_helpers.h delete mode 100644 libs/libcommon/src/tests/date_lut2.cpp delete mode 100644 libs/libcommon/src/tests/date_lut3.cpp delete mode 100644 libs/libcommon/src/tests/date_lut4.cpp delete mode 100644 libs/libcommon/src/tests/date_lut_default_timezone.cpp delete mode 100644 libs/libcommon/src/tests/multi_version.cpp rename asan_ignores.txt => tests/sanitize/asan_ignores.txt (100%) diff --git a/.gitmodules b/.gitmodules index 5f5aa5e778e..6539152492c 100644 --- a/.gitmodules +++ b/.gitmodules @@ -38,9 +38,6 @@ [submodule "contrib/prometheus-cpp"] path = contrib/prometheus-cpp url = https://github.com/jupp0r/prometheus-cpp.git -[submodule 
"contrib/junction"] - path = contrib/junction - url = https://github.com/preshing/junction.git [submodule "contrib/jemalloc"] path = contrib/jemalloc url = https://github.com/jemalloc/jemalloc.git diff --git a/cmake/sanitize.cmake b/cmake/sanitize.cmake index d7a1edbc34d..70e4ff0e1ab 100644 --- a/cmake/sanitize.cmake +++ b/cmake/sanitize.cmake @@ -19,7 +19,7 @@ else () set (SAN_FLAGS "${SAN_FLAGS} -O3") endif () -set (SANITIZE_BLACKLIST_FILE "${TiFlash_SOURCE_DIR}/asan_ignores.txt") +set (SANITIZE_BLACKLIST_FILE "${TiFlash_SOURCE_DIR}/tests/sanitize/asan_ignores.txt") set (CMAKE_SANITIZE_BLACKLIST_FILE_FLAG "-fsanitize-blacklist=${SANITIZE_BLACKLIST_FILE}") set (CMAKE_CXX_FLAGS_ASAN "${CMAKE_CXX_FLAGS_ASAN} ${SAN_FLAGS} -fsanitize=address ${CMAKE_SANITIZE_BLACKLIST_FILE_FLAG}") set (CMAKE_C_FLAGS_ASAN "${CMAKE_C_FLAGS_ASAN} ${SAN_FLAGS} -fsanitize=address ${CMAKE_SANITIZE_BLACKLIST_FILE_FLAG}") diff --git a/contrib/aws-cmake/0003-Suppress-enum-overflow.patch b/contrib/aws-cmake/0003-Suppress-enum-overflow.patch new file mode 100644 index 00000000000..b28c686333e --- /dev/null +++ b/contrib/aws-cmake/0003-Suppress-enum-overflow.patch @@ -0,0 +1,25 @@ +From ab4511229c3cf5c0bd666f4e1c61e16c2fff1bed Mon Sep 17 00:00:00 2001 +From: JaySon-Huang +Date: Mon, 3 Mar 2025 14:49:51 +0800 +Subject: [PATCH] Suppress enum overflow + +Signed-off-by: JaySon-Huang +--- + .../source/utils/EnumParseOverflowContainer.cpp | 2 +- + 1 file changed, 1 insertion(+), 1 deletion(-) + +diff --git a/src/aws-cpp-sdk-core/source/utils/EnumParseOverflowContainer.cpp b/src/aws-cpp-sdk-core/source/utils/EnumParseOverflowContainer.cpp +index eaeba1d9105..4eb02c05d25 100644 +--- a/src/aws-cpp-sdk-core/source/utils/EnumParseOverflowContainer.cpp ++++ b/src/aws-cpp-sdk-core/source/utils/EnumParseOverflowContainer.cpp +@@ -28,6 +28,6 @@ const Aws::String& EnumParseOverflowContainer::RetrieveOverflow(int hashCode) co + void EnumParseOverflowContainer::StoreOverflow(int hashCode, const Aws::String& value) 
+ { + WriterLockGuard guard(m_overflowLock); +- AWS_LOGSTREAM_WARN(LOG_TAG, "Encountered enum member " << value << " which is not modeled in your clients. You should update your clients when you get a chance."); ++ AWS_LOGSTREAM_DEBUG(LOG_TAG, "Encountered enum member " << value << " which is not modeled in your clients. You should update your clients when you get a chance."); + m_overflowMap[hashCode] = value; + } +-- +2.43.5 + diff --git a/contrib/aws-cmake/CMakeLists.txt b/contrib/aws-cmake/CMakeLists.txt index 681a50e24b9..d9d4f289be8 100644 --- a/contrib/aws-cmake/CMakeLists.txt +++ b/contrib/aws-cmake/CMakeLists.txt @@ -92,14 +92,14 @@ file(GLOB AWS_SDK_CORE_SRC "${AWS_SDK_CORE_DIR}/source/utils/xml/*.cpp" ) +## Notice: update the patch using `git format-patch` if you upgrade aws!!! +set (AWS_PATCH_FILE "${TiFlash_SOURCE_DIR}/contrib/aws-cmake/0001-More-reliable-way-to-check-if-there-is-anything-in-r.patch") execute_process( COMMAND grep "GetResponseBody().peek()" "${AWS_SDK_CORE_DIR}/source/client/AWSXmlClient.cpp" RESULT_VARIABLE HAVE_APPLY_PATCH) # grep - Normally, the exit status is 0 if selected lines are found and 1 otherwise. But the exit status is 2 if an error occurred. 
if (HAVE_APPLY_PATCH EQUAL 1) - message(STATUS "aws patch not apply: ${HAVE_APPLY_PATCH}, patching...") - ## update the patch using `git format-patch` if you upgrade aws - set (AWS_PATCH_FILE "${TiFlash_SOURCE_DIR}/contrib/aws-cmake/0001-More-reliable-way-to-check-if-there-is-anything-in-r.patch") + message(STATUS "aws patch ${AWS_PATCH_FILE} not apply: ${HAVE_APPLY_PATCH}, patching...") # apply the patch execute_process( COMMAND git apply -v "${AWS_PATCH_FILE}" @@ -109,8 +109,43 @@ if (HAVE_APPLY_PATCH EQUAL 1) if (NOT PATCH_SUCC EQUAL 0) message(FATAL_ERROR "Can not apply aws patch ${AWS_PATCH_FILE}") endif () +elseif (HAVE_APPLY_PATCH EQUAL 0) + message(STATUS "aws patch have been applied, patch=${HAVE_APPLY_PATCH}") +else () + message(FATAL_ERROR "Can not check the aws patch status") +endif () - set (AWS_PATCH_FILE "${TiFlash_SOURCE_DIR}/contrib/aws-cmake/0002-Reduce-verbose-error-logging-and-404-for-HEAD-reques.patch") +set (AWS_PATCH_FILE "${TiFlash_SOURCE_DIR}/contrib/aws-cmake/0002-Reduce-verbose-error-logging-and-404-for-HEAD-reques.patch") +execute_process( + COMMAND grep "// ignore error logging for HEAD request with 404 response code" "${AWS_SDK_CORE_DIR}/source/client/AWSXmlClient.cpp" + RESULT_VARIABLE HAVE_APPLY_PATCH) +# grep - Normally, the exit status is 0 if selected lines are found and 1 otherwise. But the exit status is 2 if an error occurred. 
+if (HAVE_APPLY_PATCH EQUAL 1) + message(STATUS "aws patch ${AWS_PATCH_FILE} not apply: ${HAVE_APPLY_PATCH}, patching...") + # apply the patch + execute_process( + COMMAND git apply -v "${AWS_PATCH_FILE}" + WORKING_DIRECTORY "${AWS_SDK_DIR}" + COMMAND_ECHO STDOUT + RESULT_VARIABLE PATCH_SUCC) + if (NOT PATCH_SUCC EQUAL 0) + message(FATAL_ERROR "Can not apply aws patch ${AWS_PATCH_FILE}") + else () + message(STATUS "aws patch done, patch=${AWS_PATCH_FILE}") + endif () +elseif (HAVE_APPLY_PATCH EQUAL 0) + message(STATUS "aws patch have been applied, patch=${HAVE_APPLY_PATCH}") +else () + message(FATAL_ERROR "Can not check the aws patch status") +endif () + +set (AWS_PATCH_FILE "${TiFlash_SOURCE_DIR}/contrib/aws-cmake/0003-Suppress-enum-overflow.patch") +execute_process( + COMMAND grep "AWS_LOGSTREAM_DEBUG.*You should update your clients when you get a chance" "${AWS_SDK_CORE_DIR}/source/utils/EnumParseOverflowContainer.cpp" + RESULT_VARIABLE HAVE_APPLY_PATCH) +# grep - Normally, the exit status is 0 if selected lines are found and 1 otherwise. But the exit status is 2 if an error occurred. 
+if (HAVE_APPLY_PATCH EQUAL 1) + message(STATUS "aws patch ${AWS_PATCH_FILE} not apply: ${HAVE_APPLY_PATCH}, patching...") # apply the patch execute_process( COMMAND git apply -v "${AWS_PATCH_FILE}" @@ -120,10 +155,10 @@ if (HAVE_APPLY_PATCH EQUAL 1) if (NOT PATCH_SUCC EQUAL 0) message(FATAL_ERROR "Can not apply aws patch ${AWS_PATCH_FILE}") else () - message(STATUS "aws patch done") + message(STATUS "aws patch done, patch=${AWS_PATCH_FILE}") endif () elseif (HAVE_APPLY_PATCH EQUAL 0) - message(STATUS "aws patch have been applied: ${HAVE_APPLY_PATCH}") + message(STATUS "aws patch have been applied, patch=${HAVE_APPLY_PATCH}") else () message(FATAL_ERROR "Can not check the aws patch status") endif () diff --git a/dbms/CMakeLists.txt b/dbms/CMakeLists.txt index fdc5a70fbc5..70939c96436 100644 --- a/dbms/CMakeLists.txt +++ b/dbms/CMakeLists.txt @@ -16,8 +16,6 @@ if (USE_INCLUDE_WHAT_YOU_USE) set (CMAKE_CXX_INCLUDE_WHAT_YOU_USE ${IWYU_PATH}) endif () -include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/find_vectorclass.cmake) - set (CONFIG_VERSION ${CMAKE_CURRENT_BINARY_DIR}/src/Common/config_version.h) set (CONFIG_COMMON ${CMAKE_CURRENT_BINARY_DIR}/src/Common/config.h) set (CONFIG_BUILD ${CMAKE_CURRENT_BINARY_DIR}/src/Common/config_build.cpp) @@ -172,10 +170,6 @@ if (CMAKE_BUILD_TYPE_UC STREQUAL "RELEASE" OR CMAKE_BUILD_TYPE_UC STREQUAL "RELW set_source_files_properties( src/Dictionaries/CacheDictionary.cpp src/Dictionaries/TrieDictionary.cpp - src/Dictionaries/ComplexKeyCacheDictionary.cpp - src/Dictionaries/ComplexKeyCacheDictionary_generate1.cpp - src/Dictionaries/ComplexKeyCacheDictionary_generate2.cpp - src/Dictionaries/ComplexKeyCacheDictionary_generate3.cpp src/Dictionaries/HTTPDictionarySource.cpp src/Dictionaries/LibraryDictionarySource.cpp src/Dictionaries/ExecutableDictionarySource.cpp diff --git a/dbms/cmake/find_vectorclass.cmake b/dbms/cmake/find_vectorclass.cmake deleted file mode 100644 index 62998a9b4b1..00000000000 --- a/dbms/cmake/find_vectorclass.cmake +++ 
/dev/null @@ -1,29 +0,0 @@ -# Copyright 2023 PingCAP, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -option (ENABLE_VECTORCLASS "Faster math functions with vectorclass lib" OFF) - -if (ENABLE_VECTORCLASS) - - set (VECTORCLASS_INCLUDE_PATHS "${TiFlash_SOURCE_DIR}/contrib/vectorclass" CACHE STRING "Path of vectorclass library") - find_path (VECTORCLASS_INCLUDE_DIR NAMES vectorf128.h PATHS ${VECTORCLASS_INCLUDE_PATHS}) - - if (VECTORCLASS_INCLUDE_DIR) - set (USE_VECTORCLASS 1) - endif () - - message (STATUS "Using vectorclass=${USE_VECTORCLASS}: ${VECTORCLASS_INCLUDE_DIR}") - -endif () diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h index b57de57774a..ab5a43cbd19 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxAny.h @@ -49,6 +49,8 @@ struct SingleValueDataFixed : public CommonImpl using ColumnType = std::conditional_t, ColumnDecimal, ColumnVector>; public: + static bool needArena() { return false; } + bool has() const { return has_value; } void setCollators(const TiDB::TiDBCollators &) {} @@ -225,6 +227,8 @@ struct SingleValueDataString : public CommonImpl char small_data[MAX_SMALL_STRING_SIZE]{}; /// Including the terminating zero. public: + static bool needArena() { return true; } + bool has() const { return size >= 0; } const char * getData() const { return size <= MAX_SMALL_STRING_SIZE ? 
small_data : large_data; } @@ -439,6 +443,8 @@ struct SingleValueDataGeneric : public CommonImpl Field value; public: + static bool needArena() { return false; } + bool has() const { return !value.isNull(); } void setCollators(const TiDB::TiDBCollators &) {} @@ -792,6 +798,8 @@ class AggregateFunctionsSingleValue final } const char * getHeaderFilePath() const override { return __FILE__; } + + bool allocatesMemoryInArena() const override { return Data::needArena(); } }; } // namespace DB diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMinMaxWindow.h b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxWindow.h index 636d2192292..89023c00be4 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMinMaxWindow.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionMinMaxWindow.h @@ -62,6 +62,8 @@ struct SingleValueDataFixedForWindow } public: + static bool needArena() { return false; } + void insertMaxResultInto(IColumn & to) const { insertMinOrMaxResultInto(to); } void insertMinResultInto(IColumn & to) const { insertMinOrMaxResultInto(to); } @@ -143,6 +145,8 @@ struct SingleValueDataStringForWindow } public: + static bool needArena() { return false; } + void insertMaxResultInto(IColumn & to) const { insertMinOrMaxResultInto(to); } void insertMinResultInto(IColumn & to) const { insertMinOrMaxResultInto(to); } @@ -200,6 +204,8 @@ struct SingleValueDataGenericForWindow } public: + static bool needArena() { return false; } + void insertMaxResultInto(IColumn & to) const { insertMinOrMaxResultInto(to); } void insertMinResultInto(IColumn & to) const { insertMinOrMaxResultInto(to); } diff --git a/dbms/src/Client/Connection.cpp b/dbms/src/Client/Connection.cpp index 5682a69c800..1f0345b8cca 100644 --- a/dbms/src/Client/Connection.cpp +++ b/dbms/src/Client/Connection.cpp @@ -298,8 +298,7 @@ void Connection::sendQuery( const String & query_id_, UInt64 stage, const Settings * settings, - const ClientInfo * client_info, - bool with_pending_data) + const ClientInfo * 
client_info) { if (!connected) connect(); @@ -349,11 +348,8 @@ void Connection::sendQuery( block_out.reset(); /// Send empty block which means end of data. - if (!with_pending_data) - { - sendData(Block()); - out->next(); - } + sendData(Block()); + out->next(); } @@ -404,67 +400,6 @@ void Connection::sendPreparedData(ReadBuffer & input, size_t size, const String out->next(); } - -void Connection::sendExternalTablesData(ExternalTablesData & data) -{ - if (data.empty()) - { - /// Send empty block, which means end of data transfer. - sendData(Block()); - return; - } - - Stopwatch watch; - size_t out_bytes = out ? out->count() : 0; - size_t maybe_compressed_out_bytes = maybe_compressed_out ? maybe_compressed_out->count() : 0; - size_t rows = 0; - - for (auto & elem : data) - { - elem.first->readPrefix(); - while (Block block = elem.first->read()) - { - rows += block.rows(); - sendData(block, elem.second); - } - elem.first->readSuffix(); - } - - /// Send empty block, which means end of data transfer. 
- sendData(Block()); - - out_bytes = out->count() - out_bytes; - maybe_compressed_out_bytes = maybe_compressed_out->count() - maybe_compressed_out_bytes; - - auto get_logging_msg = [&]() -> String { - const double elapsed_seconds = watch.elapsedSeconds(); - - FmtBuffer fmt_buf; - fmt_buf.fmtAppend( - "Sent data for {} external tables, total {} rows in {:.3f} sec., {:.3f} rows/sec., " - "{:.3f} MiB ({:.3f} MiB/sec.)", - data.size(), - rows, - elapsed_seconds, - 1.0 * rows / elapsed_seconds, - maybe_compressed_out_bytes / 1048576.0, - maybe_compressed_out_bytes / 1048576.0 / elapsed_seconds); - - if (compression == Protocol::Compression::Enable) - fmt_buf.fmtAppend( - ", compressed {:.3f} times to {:.3f} MiB ({:.3f} MiB/sec.)", - 1.0 * maybe_compressed_out_bytes / out_bytes, - out_bytes / 1048576.0, - out_bytes / 1048576.0 / elapsed_seconds); - else - fmt_buf.append(", no compression."); - return fmt_buf.toString(); - }; - - LOG_DEBUG(log_wrapper.get(), get_logging_msg()); -} - - bool Connection::poll(size_t timeout_microseconds) { return static_cast(*in).poll(timeout_microseconds); diff --git a/dbms/src/Client/Connection.h b/dbms/src/Client/Connection.h index df4b090c52d..43e7a75a167 100644 --- a/dbms/src/Client/Connection.h +++ b/dbms/src/Client/Connection.h @@ -37,11 +37,6 @@ namespace DB { class ClientInfo; -/// The stream of blocks reading from the table and its name -using ExternalTableData = std::pair; -/// Vector of pairs describing tables -using ExternalTablesData = std::vector; - class Connection; using ConnectionPtr = std::shared_ptr; @@ -166,14 +161,11 @@ class Connection : private boost::noncopyable const String & query_id_ = "", UInt64 stage = QueryProcessingStage::Complete, const Settings * settings = nullptr, - const ClientInfo * client_info = nullptr, - bool with_pending_data = false); + const ClientInfo * client_info = nullptr); void sendCancel(); /// Send block of data; if name is specified, server will write it to external (temporary) table of 
that name. void sendData(const Block & block, const String & name = ""); - /// Send all contents of external (temporary) tables. - void sendExternalTablesData(ExternalTablesData & data); /// Send prepared block of data (serialized and, if need, compressed), that will be read from 'input'. /// You could pass size of serialized/compressed block. diff --git a/dbms/src/Client/MultiplexedConnections.cpp b/dbms/src/Client/MultiplexedConnections.cpp index b886f99dd8d..1e038798777 100644 --- a/dbms/src/Client/MultiplexedConnections.cpp +++ b/dbms/src/Client/MultiplexedConnections.cpp @@ -71,34 +71,11 @@ MultiplexedConnections::MultiplexedConnections( block_extra_info = std::make_unique(); } -void MultiplexedConnections::sendExternalTablesData(std::vector & data) -{ - std::lock_guard lock(cancel_mutex); - - if (!sent_query) - throw Exception("Cannot send external tables data: query not yet sent.", ErrorCodes::LOGICAL_ERROR); - - if (data.size() != active_connection_count) - throw Exception("Mismatch between replicas and data sources", ErrorCodes::MISMATCH_REPLICAS_DATA_SOURCES); - - auto it = data.begin(); - for (ReplicaState & state : replica_states) - { - Connection * connection = state.connection; - if (connection != nullptr) - { - connection->sendExternalTablesData(*it); - ++it; - } - } -} - void MultiplexedConnections::sendQuery( const String & query, const String & query_id, UInt64 stage, - const ClientInfo * client_info, - bool with_pending_data) + const ClientInfo * client_info) { std::lock_guard lock(cancel_mutex); @@ -115,7 +92,7 @@ void MultiplexedConnections::sendQuery( if (connection == nullptr) throw Exception("MultiplexedConnections: Internal error", ErrorCodes::LOGICAL_ERROR); - connection->sendQuery(query, query_id, stage, &query_settings, client_info, with_pending_data); + connection->sendQuery(query, query_id, stage, &query_settings, client_info); } } else @@ -124,7 +101,7 @@ void MultiplexedConnections::sendQuery( if (connection == nullptr) throw 
Exception("MultiplexedConnections: Internal error", ErrorCodes::LOGICAL_ERROR); - connection->sendQuery(query, query_id, stage, &settings, client_info, with_pending_data); + connection->sendQuery(query, query_id, stage, &settings, client_info); } sent_query = true; diff --git a/dbms/src/Client/MultiplexedConnections.h b/dbms/src/Client/MultiplexedConnections.h index 56ccc2740ee..6da4a5c7b83 100644 --- a/dbms/src/Client/MultiplexedConnections.h +++ b/dbms/src/Client/MultiplexedConnections.h @@ -45,16 +45,12 @@ class MultiplexedConnections final : private boost::noncopyable const ThrottlerPtr & throttler_, bool append_extra_info); - /// Send all content of external tables to replicas. - void sendExternalTablesData(std::vector & data); - /// Send request to replicas. void sendQuery( const String & query, const String & query_id = "", UInt64 stage = QueryProcessingStage::Complete, - const ClientInfo * client_info = nullptr, - bool with_pending_data = false); + const ClientInfo * client_info = nullptr); /// Get packet from any replica. 
Connection::Packet receivePacket(); diff --git a/dbms/src/Columns/ColumnDecimal.cpp b/dbms/src/Columns/ColumnDecimal.cpp index 45df9eb6816..72aded07965 100644 --- a/dbms/src/Columns/ColumnDecimal.cpp +++ b/dbms/src/Columns/ColumnDecimal.cpp @@ -76,13 +76,14 @@ ALWAYS_INLINE inline char * serializeDecimal256Helper(char * dst, const Decimal2 dst += sizeof(size_t); const size_t limb_size = limb_count * sizeof(boost::multiprecision::limb_type); - inline_memcpy(dst, val.limbs(), limb_size); + memcpy(dst, val.limbs(), limb_size); dst += limb_size; return dst; } -ALWAYS_INLINE inline const char * deserializeDecimal256Helper(Decimal256 & value, const char * ptr) +ALWAYS_INLINE inline const char * deserializeDecimal256Helper(Decimal256 & new_value, const char * ptr) { + Decimal256 value; auto & val = value.value.backend(); size_t offset = 0; @@ -98,6 +99,7 @@ ALWAYS_INLINE inline const char * deserializeDecimal256Helper(Decimal256 & value val.normalize(); offset += limb_count * sizeof(boost::multiprecision::limb_type); + new_value = value; return ptr + offset; } diff --git a/dbms/src/Columns/ColumnString.cpp b/dbms/src/Columns/ColumnString.cpp index 6f1af3b1221..e22390ec00e 100644 --- a/dbms/src/Columns/ColumnString.cpp +++ b/dbms/src/Columns/ColumnString.cpp @@ -498,7 +498,9 @@ void ColumnString::countSerializeByteSizeForCmp( const NullMap * nullmap, const TiDB::TiDBCollatorPtr & collator) const { - // Skip decoding collator for bin collator so we can avoid counting code points, which may be slow. + // For now, sortKeyReservedSpaceMultipler() of bin collator(padding or non-padding) is 1. + // So bin collator will skip to decode collator. + // And other collators will first count code point then compute the needed memory. 
if (collator != nullptr && collator->sortKeyReservedSpaceMultipler() > 1) { if (nullmap != nullptr) @@ -1298,7 +1300,7 @@ void ColumnString::deserializeAndInsertFromPosImpl( pos[i] += sizeof(UInt32); chars.resize(char_size + str_size); - memcpySmallAllowReadWriteOverflow15(&chars[char_size], pos[i], str_size); + inline_memcpy(&chars[char_size], pos[i], str_size); char_size += str_size; offsets[prev_size + i] = char_size; pos[i] += str_size; @@ -1357,7 +1359,7 @@ void ColumnString::deserializeAndInsertFromPosForColumnArrayImpl( pos[i] += sizeof(UInt32); chars.resize(char_size + str_size); - memcpySmallAllowReadWriteOverflow15(&chars[char_size], pos[i], str_size); + inline_memcpy(&chars[char_size], pos[i], str_size); char_size += str_size; offsets[j] = char_size; @@ -1379,7 +1381,7 @@ void ColumnString::deserializeAndInsertFromPosForColumnArrayImpl( offsets[j] = char_size; } chars.resize(char_size); - memcpySmallAllowReadWriteOverflow15(&chars[prev_char_size], pos[i], char_size - prev_char_size); + inline_memcpy(&chars[prev_char_size], pos[i], char_size - prev_char_size); pos[i] += char_size - prev_char_size; } } diff --git a/dbms/src/Columns/ColumnString.h b/dbms/src/Columns/ColumnString.h index 494b3b0894c..6201507316b 100644 --- a/dbms/src/Columns/ColumnString.h +++ b/dbms/src/Columns/ColumnString.h @@ -119,7 +119,7 @@ class ColumnString final : public COWPtrHelper const NullMap * nullmap, const TiDB::TiDBCollatorPtr & collator) const; - template + template void serializeToPosImplType( PaddedPODArray & pos, size_t start, @@ -127,7 +127,7 @@ class ColumnString final : public COWPtrHelper const TiDB::TiDBCollatorPtr & collator, String * sort_key_container, const NullMap * nullmap) const; - template + template void serializeToPosImpl( PaddedPODArray & pos, size_t start, @@ -136,7 +136,7 @@ class ColumnString final : public COWPtrHelper String * sort_key_container, const NullMap * nullmap) const; - template + template void serializeToPosForColumnArrayImplType( 
PaddedPODArray & pos, size_t start, @@ -145,7 +145,12 @@ class ColumnString final : public COWPtrHelper const TiDB::TiDBCollatorPtr & collator, String * sort_key_container, const NullMap * nullmap) const; - template + template < + bool compare_semantics, + bool has_null, + bool need_decode_collator, + typename DerivedCollator, + bool has_nullmap> void serializeToPosForColumnArrayImpl( PaddedPODArray & pos, size_t start, diff --git a/dbms/src/Columns/FilterDescription.h b/dbms/src/Columns/FilterDescription.h index 394ad794314..4d3931825a8 100644 --- a/dbms/src/Columns/FilterDescription.h +++ b/dbms/src/Columns/FilterDescription.h @@ -19,7 +19,7 @@ namespace DB { -/// Support methods for implementation of WHERE, PREWHERE and HAVING. +/// Support methods for implementation of WHERE and HAVING. /// Analyze if the column for filter is constant thus filter is always false or always true. diff --git a/dbms/src/Columns/IColumn.h b/dbms/src/Columns/IColumn.h index 19bd1d2ea10..af3be2405d1 100644 --- a/dbms/src/Columns/IColumn.h +++ b/dbms/src/Columns/IColumn.h @@ -325,9 +325,9 @@ class IColumn : public COWPtr /// Deserialize and insert data from pos and forward each pos[i] to the end of serialized data. /// Note: /// 1. The pos pointer must not be nullptr. - /// 2. The memory of pos must be accessible to overflow 15 bytes(i.e. PaddedPODArray) for speeding up memcpy.(e.g. for ColumnString) - /// 3. If use_nt_align_buffer is true and AVX2 is enabled, non-temporal store may be used when data memory is aligned to FULL_VECTOR_SIZE_AVX2(64 bytes). - /// 4. If non-temporal store is used, the data will be copied to a align_buffer firstly and then flush to column data if full. After the + /// 2. If use_nt_align_buffer is true and AVX2 is enabled, non-temporal store may be used when data memory is aligned to FULL_VECTOR_SIZE_AVX2(64 bytes). + /// The memory of pos must be accessible to **overflow 15 bytes**(i.e. PaddedPODArray) for speeding up memcpy when use_nt_align_buffer is true. 
+ /// 3. If non-temporal store is used, the data will be copied to a align_buffer firstly and then flush to column data if full. After the /// last call, flushNTAlignBuffer must be called to flush the remaining unaligned data from align_buffer into column data. During the /// process, any function that may change the alignment of column data should not be called otherwise a exception will be thrown. /// Example: diff --git a/dbms/src/Columns/VirtualColumnUtils.cpp b/dbms/src/Columns/VirtualColumnUtils.cpp index 4c492ba96ea..a01864d42f2 100644 --- a/dbms/src/Columns/VirtualColumnUtils.cpp +++ b/dbms/src/Columns/VirtualColumnUtils.cpp @@ -78,19 +78,17 @@ static ASTPtr buildWhereExpression(const ASTs & functions) void filterBlockWithQuery(const ASTPtr & query, Block & block, const Context & context) { const auto & select = typeid_cast(*query); - if (!select.where_expression && !select.prewhere_expression) + if (!select.where_expression) return; NameSet columns; for (const auto & it : block.getNamesAndTypesList()) columns.insert(it.name); - /// We will create an expression that evaluates the expressions in WHERE and PREWHERE, depending only on the existing columns. + /// We will create an expression that evaluates the expressions in WHERE, depending only on the existing columns. std::vector functions; if (select.where_expression) extractFunctions(select.where_expression, columns, functions); - if (select.prewhere_expression) - extractFunctions(select.prewhere_expression, columns, functions); ASTPtr expression_ast = buildWhereExpression(functions); if (!expression_ast) diff --git a/dbms/src/Columns/VirtualColumnUtils.h b/dbms/src/Columns/VirtualColumnUtils.h index c5f05b727b5..388990f172c 100644 --- a/dbms/src/Columns/VirtualColumnUtils.h +++ b/dbms/src/Columns/VirtualColumnUtils.h @@ -22,7 +22,7 @@ namespace DB::VirtualColumnUtils { -/// Leave in the block only the rows that fit under the WHERE clause and the PREWHERE clause of the query. 
+/// Leave in the block only the rows that fit under the WHERE clause of the query. /// Only elements of the outer conjunction are considered, depending only on the columns present in the block. /// Returns true if at least one row is discarded. void filterBlockWithQuery(const ASTPtr & query, Block & block, const Context & context); diff --git a/dbms/src/Columns/filterColumn.cpp b/dbms/src/Columns/filterColumn.cpp index 33a95159b88..35c0eb06fcb 100644 --- a/dbms/src/Columns/filterColumn.cpp +++ b/dbms/src/Columns/filterColumn.cpp @@ -30,6 +30,17 @@ namespace DB namespace { + +constexpr std::array MASKS = [] constexpr { + std::array masks = {}; + for (int i = 0; i < 64; ++i) + { + masks[i] = ~((1ULL << i) - 1); + } + masks[64] = 0; + return masks; +}(); + /// Implementation details of filterArraysImpl function, used as template parameter. /// Allow to build or not to build offsets array. @@ -134,45 +145,81 @@ void filterArraysImplGeneric( while (filt_pos < filt_end_aligned) { auto mask = ToBits64(filt_pos); + while (mask) + { + // 100011111000 -> index: 3, length: 5, mask: 100000000000 + size_t index = std::countr_zero(mask); + size_t length = std::countr_one(mask >> index); + copy_chunk(offsets_pos + index, length); + mask &= MASKS[index + length]; + } + + filt_pos += FILTER_SIMD_BYTES; + offsets_pos += FILTER_SIMD_BYTES; + } + + while (filt_pos < filt_end) + { + if (*filt_pos) + copy_chunk(offsets_pos, 1); + + ++filt_pos; + ++offsets_pos; + } +} + +/// filterImplAligned is used for aligned part of filter. 
+template +inline void filterImplAligned( + const UInt8 *& filt_pos, + const UInt8 *& filt_end_aligned, + const T *& data_pos, + Container & res_data) +{ + while (filt_pos < filt_end_aligned) + { + UInt64 mask = ToBits64(filt_pos); if likely (0 != mask) { - if (const auto prefix_to_copy = prefixToCopy(mask); 0xFF != prefix_to_copy) + if (const UInt8 prefix_to_copy = prefixToCopy(mask); 0xFF != prefix_to_copy) { - copy_chunk(offsets_pos, prefix_to_copy); + res_data.insert(data_pos, data_pos + prefix_to_copy); } else { - if (const auto suffix_to_copy = suffixToCopy(mask); 0xFF != suffix_to_copy) + if (const UInt8 suffix_to_copy = suffixToCopy(mask); 0xFF != suffix_to_copy) { - copy_chunk(offsets_pos + FILTER_SIMD_BYTES - suffix_to_copy, suffix_to_copy); + res_data.insert(data_pos + FILTER_SIMD_BYTES - suffix_to_copy, data_pos + FILTER_SIMD_BYTES); } else { while (mask) { size_t index = std::countr_zero(mask); - copy_chunk(offsets_pos + index, 1); + res_data.push_back(data_pos[index]); mask &= mask - 1; } } } } + // There is an alternative implementation which is similar to the one in filterArraysImplGeneric. + // But according to the micro benchmark, the below implementation is slower. + // So we choose to still use the above implementation. 
+ // while (mask) + // { + // // 100011111000 -> index: 3, length: 5, mask: 100000000000 + // size_t index = std::countr_zero(mask); + // size_t length = std::countr_one(mask >> index); + // res_data.insert(data_pos + index, data_pos + index + length); + // mask &= MASKS[index + length]; + // } filt_pos += FILTER_SIMD_BYTES; - offsets_pos += FILTER_SIMD_BYTES; - } - - while (filt_pos < filt_end) - { - if (*filt_pos) - copy_chunk(offsets_pos, 1); - - ++filt_pos; - ++offsets_pos; + data_pos += FILTER_SIMD_BYTES; } } -} // namespace +} // namespace template void filterArraysImpl( @@ -239,49 +286,6 @@ INSTANTIATE(Float64) #undef INSTANTIATE -namespace -{ -template -inline void filterImplAligned( - const UInt8 *& filt_pos, - const UInt8 *& filt_end_aligned, - const T *& data_pos, - Container & res_data) -{ - while (filt_pos < filt_end_aligned) - { - UInt64 mask = ToBits64(filt_pos); - if likely (0 != mask) - { - if (const UInt8 prefix_to_copy = prefixToCopy(mask); 0xFF != prefix_to_copy) - { - res_data.insert(data_pos, data_pos + prefix_to_copy); - } - else - { - if (const UInt8 suffix_to_copy = suffixToCopy(mask); 0xFF != suffix_to_copy) - { - res_data.insert(data_pos + FILTER_SIMD_BYTES - suffix_to_copy, data_pos + FILTER_SIMD_BYTES); - } - else - { - while (mask) - { - size_t index = std::countr_zero(mask); - res_data.push_back(data_pos[index]); - mask &= mask - 1; - } - } - } - } - - filt_pos += FILTER_SIMD_BYTES; - data_pos += FILTER_SIMD_BYTES; - } -} -} // namespace - - template void filterImpl(const UInt8 * filt_pos, const UInt8 * filt_end, const T * data_pos, Container & res_data) { diff --git a/dbms/src/Columns/tests/bench_column_filter.cpp b/dbms/src/Columns/tests/bench_column_filter.cpp index 77f61ffdb21..4e6eec4ecf7 100644 --- a/dbms/src/Columns/tests/bench_column_filter.cpp +++ b/dbms/src/Columns/tests/bench_column_filter.cpp @@ -13,6 +13,7 @@ // limitations under the License. 
+#include #include #include #include @@ -130,6 +131,72 @@ ColumnPtr filterAVX2(ColumnPtr & col, IColumn::Filter & filt, ssize_t result_siz const UInt8 * filt_end = filt_pos + size; const Int64 * data_pos = &data[0]; + const UInt8 * filt_end_aligned = filt_pos + (filt_end - filt_pos) / FILTER_SIMD_BYTES * FILTER_SIMD_BYTES; + while (filt_pos < filt_end_aligned) + { + UInt64 mask = ToBits64(filt_pos); + if likely (0 != mask) + { + if (const UInt8 prefix_to_copy = prefixToCopy(mask); 0xFF != prefix_to_copy) + { + res_data.insert(data_pos, data_pos + prefix_to_copy); + } + else + { + if (const UInt8 suffix_to_copy = suffixToCopy(mask); 0xFF != suffix_to_copy) + { + res_data.insert(data_pos + FILTER_SIMD_BYTES - suffix_to_copy, data_pos + FILTER_SIMD_BYTES); + } + else + { + while (mask) + { + size_t index = std::countr_zero(mask); + res_data.push_back(data_pos[index]); + mask &= mask - 1; + } + } + } + } + + filt_pos += FILTER_SIMD_BYTES; + data_pos += FILTER_SIMD_BYTES; + } + + /// Process the tail. 
+ while (filt_pos < filt_end) + { + if (*filt_pos) + res_data.push_back(*data_pos); + ++filt_pos; + ++data_pos; + } + + return res; +} + +ColumnPtr filterCurrent(ColumnPtr & col, IColumn::Filter & filt, ssize_t result_size_hint) +{ + const auto & data = typeid_cast *>(col.get())->getData(); + size_t size = col->size(); + if (size != filt.size()) + throw Exception("Size of filter doesn't match size of column.", ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH); + + auto res = ColumnVector::create(); + using Container = ColumnVector::Container; + Container & res_data = res->getData(); + + if (result_size_hint) + { + if (result_size_hint < 0) + result_size_hint = countBytesInFilter(filt); + res_data.reserve(result_size_hint); + } + + const UInt8 * filt_pos = &filt[0]; + const UInt8 * filt_end = filt_pos + size; + const Int64 * data_pos = &data[0]; + filterImpl(filt_pos, filt_end, data_pos, res_data); return res; @@ -139,6 +206,7 @@ enum class FilterVersion { SSE2, AVX2, + Current, }; template @@ -158,7 +226,15 @@ void columnFilter(benchmark::State & state, Args &&... args) { for (auto _ : state) { - auto t = filterSSE2(col, filter, set_n * sizeof(Int64)); + auto t = filterSSE2(col, filter, set_n); + benchmark::DoNotOptimize(t); + } + } + else if (version == FilterVersion::AVX2) + { + for (auto _ : state) + { + auto t = filterAVX2(col, filter, set_n); benchmark::DoNotOptimize(t); } } @@ -166,7 +242,7 @@ void columnFilter(benchmark::State & state, Args &&... args) { for (auto _ : state) { - auto t = filterAVX2(col, filter, set_n * sizeof(Int64)); + auto t = filterCurrent(col, filter, set_n); benchmark::DoNotOptimize(t); } } @@ -174,29 +250,42 @@ void columnFilter(benchmark::State & state, Args &&... 
args) BENCHMARK_CAPTURE(columnFilter, sse2_00, FilterVersion::SSE2, DEFAULT_BLOCK_SIZE, 0.00); BENCHMARK_CAPTURE(columnFilter, avx2_00, FilterVersion::AVX2, DEFAULT_BLOCK_SIZE, 0.00); +BENCHMARK_CAPTURE(columnFilter, cur_00, FilterVersion::Current, DEFAULT_BLOCK_SIZE, 0.00); BENCHMARK_CAPTURE(columnFilter, sse2_01, FilterVersion::SSE2, DEFAULT_BLOCK_SIZE, 0.01); BENCHMARK_CAPTURE(columnFilter, avx2_01, FilterVersion::AVX2, DEFAULT_BLOCK_SIZE, 0.01); +BENCHMARK_CAPTURE(columnFilter, cur_01, FilterVersion::Current, DEFAULT_BLOCK_SIZE, 0.01); BENCHMARK_CAPTURE(columnFilter, sse2_10, FilterVersion::SSE2, DEFAULT_BLOCK_SIZE, 0.10); BENCHMARK_CAPTURE(columnFilter, avx2_10, FilterVersion::AVX2, DEFAULT_BLOCK_SIZE, 0.10); +BENCHMARK_CAPTURE(columnFilter, cur_10, FilterVersion::Current, DEFAULT_BLOCK_SIZE, 0.10); BENCHMARK_CAPTURE(columnFilter, sse2_20, FilterVersion::SSE2, DEFAULT_BLOCK_SIZE, 0.20); BENCHMARK_CAPTURE(columnFilter, avx2_20, FilterVersion::AVX2, DEFAULT_BLOCK_SIZE, 0.20); +BENCHMARK_CAPTURE(columnFilter, cur_20, FilterVersion::Current, DEFAULT_BLOCK_SIZE, 0.20); BENCHMARK_CAPTURE(columnFilter, sse2_30, FilterVersion::SSE2, DEFAULT_BLOCK_SIZE, 0.30); BENCHMARK_CAPTURE(columnFilter, avx2_30, FilterVersion::AVX2, DEFAULT_BLOCK_SIZE, 0.30); +BENCHMARK_CAPTURE(columnFilter, cur_30, FilterVersion::Current, DEFAULT_BLOCK_SIZE, 0.30); BENCHMARK_CAPTURE(columnFilter, sse2_40, FilterVersion::SSE2, DEFAULT_BLOCK_SIZE, 0.40); BENCHMARK_CAPTURE(columnFilter, avx2_40, FilterVersion::AVX2, DEFAULT_BLOCK_SIZE, 0.40); +BENCHMARK_CAPTURE(columnFilter, cur_40, FilterVersion::Current, DEFAULT_BLOCK_SIZE, 0.40); BENCHMARK_CAPTURE(columnFilter, sse2_50, FilterVersion::SSE2, DEFAULT_BLOCK_SIZE, 0.50); BENCHMARK_CAPTURE(columnFilter, avx2_50, FilterVersion::AVX2, DEFAULT_BLOCK_SIZE, 0.50); +BENCHMARK_CAPTURE(columnFilter, cur_50, FilterVersion::Current, DEFAULT_BLOCK_SIZE, 0.50); BENCHMARK_CAPTURE(columnFilter, sse2_60, FilterVersion::SSE2, DEFAULT_BLOCK_SIZE, 0.60); 
BENCHMARK_CAPTURE(columnFilter, avx2_60, FilterVersion::AVX2, DEFAULT_BLOCK_SIZE, 0.60); +BENCHMARK_CAPTURE(columnFilter, cur_60, FilterVersion::Current, DEFAULT_BLOCK_SIZE, 0.60); BENCHMARK_CAPTURE(columnFilter, sse2_70, FilterVersion::SSE2, DEFAULT_BLOCK_SIZE, 0.70); BENCHMARK_CAPTURE(columnFilter, avx2_70, FilterVersion::AVX2, DEFAULT_BLOCK_SIZE, 0.70); +BENCHMARK_CAPTURE(columnFilter, cur_70, FilterVersion::Current, DEFAULT_BLOCK_SIZE, 0.70); BENCHMARK_CAPTURE(columnFilter, sse2_80, FilterVersion::SSE2, DEFAULT_BLOCK_SIZE, 0.80); BENCHMARK_CAPTURE(columnFilter, avx2_80, FilterVersion::AVX2, DEFAULT_BLOCK_SIZE, 0.80); +BENCHMARK_CAPTURE(columnFilter, cur_80, FilterVersion::Current, DEFAULT_BLOCK_SIZE, 0.80); BENCHMARK_CAPTURE(columnFilter, sse2_90, FilterVersion::SSE2, DEFAULT_BLOCK_SIZE, 0.90); BENCHMARK_CAPTURE(columnFilter, avx2_90, FilterVersion::AVX2, DEFAULT_BLOCK_SIZE, 0.90); +BENCHMARK_CAPTURE(columnFilter, cur_90, FilterVersion::Current, DEFAULT_BLOCK_SIZE, 0.90); BENCHMARK_CAPTURE(columnFilter, sse2_99, FilterVersion::SSE2, DEFAULT_BLOCK_SIZE, 0.99); BENCHMARK_CAPTURE(columnFilter, avx2_99, FilterVersion::AVX2, DEFAULT_BLOCK_SIZE, 0.99); +BENCHMARK_CAPTURE(columnFilter, cur_99, FilterVersion::Current, DEFAULT_BLOCK_SIZE, 0.99); BENCHMARK_CAPTURE(columnFilter, sse2_100, FilterVersion::SSE2, DEFAULT_BLOCK_SIZE, 1.00); BENCHMARK_CAPTURE(columnFilter, avx2_100, FilterVersion::AVX2, DEFAULT_BLOCK_SIZE, 1.00); +BENCHMARK_CAPTURE(columnFilter, cur_100, FilterVersion::Current, DEFAULT_BLOCK_SIZE, 1.00); } // namespace bench diff --git a/dbms/src/Columns/tests/bench_column_string_filter.cpp b/dbms/src/Columns/tests/bench_column_string_filter.cpp new file mode 100644 index 00000000000..426f4c3b9c4 --- /dev/null +++ b/dbms/src/Columns/tests/bench_column_string_filter.cpp @@ -0,0 +1,261 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include +#include +#include +#include +#include +#include + +#include + +using namespace DB; + +namespace bench::String +{ + +IColumn::Filter createRandomFilter(size_t n, size_t set_n) +{ + assert(n >= set_n); + + IColumn::Filter filter(set_n, 1); + filter.resize_fill_zero(n, 0); + + std::random_device rd; + std::mt19937 g(rd()); + std::shuffle(filter.begin(), filter.end(), g); + return filter; +} + +struct ResultOffsetsBuilder +{ + IColumn::Offsets & res_offsets; + IColumn::Offset current_src_offset = 0; + + explicit ResultOffsetsBuilder(IColumn::Offsets * res_offsets_) + : res_offsets(*res_offsets_) + {} + + void reserve(size_t result_size_hint) { res_offsets.reserve(result_size_hint); } + + void insertChunk( + size_t n, + const IColumn::Offset * src_offsets_pos, + bool first, + IColumn::Offset chunk_offset, + size_t chunk_size) + { + const auto offsets_size_old = res_offsets.size(); + res_offsets.resize(offsets_size_old + n); + inline_memcpy(&res_offsets[offsets_size_old], src_offsets_pos, n * sizeof(IColumn::Offset)); + + if (!first) + { + /// difference between current and actual offset + const auto diff_offset = chunk_offset - current_src_offset; + + if (diff_offset > 0) + { + auto * res_offsets_pos = &res_offsets[offsets_size_old]; + + /// adjust offsets + for (size_t i = 0; i < n; ++i) + res_offsets_pos[i] -= diff_offset; + } + } + current_src_offset += chunk_size; + } +}; + +ColumnPtr filterAVX2(ColumnPtr & col, IColumn::Filter & filt, ssize_t result_size_hint) +{ + auto res = ColumnString::create(); + + const auto & src = 
typeid_cast(*col); + const auto & src_elems = src.getChars(); + const auto & src_offsets = src.getOffsets(); + + auto & res_elems = res->getChars(); + auto & res_offsets = res->getOffsets(); + + const size_t size = src_offsets.size(); + if (size != filt.size()) + throw Exception( + fmt::format("size of filter {} doesn't match size of column {}", filt.size(), size), + ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH); + + ResultOffsetsBuilder result_offsets_builder(&res_offsets); + + if (result_size_hint) + { + if (result_size_hint < 0) + result_size_hint = countBytesInFilter(filt); + + result_offsets_builder.reserve(result_size_hint); + + if (result_size_hint < 1000000000 && src_elems.size() < 1000000000) /// Avoid overflow. + res_elems.reserve((result_size_hint * src_elems.size() + size - 1) / size); + } + + const UInt8 * filt_pos = filt.data(); + const auto * filt_end = filt_pos + size; + + const auto * offsets_pos = src_offsets.data(); + const auto * offsets_begin = offsets_pos; + + /// copy n arrays from ending at *end_offset_ptr + const auto copy_chunk = [&](const IColumn::Offset * offset_ptr, size_t n) { + const auto first = offset_ptr == offsets_begin; + + const auto chunk_offset = first ? 
0 : offset_ptr[-1]; + const auto chunk_size = offset_ptr[n - 1] - chunk_offset; + + result_offsets_builder.insertChunk(n, offset_ptr, first, chunk_offset, chunk_size); + + /// copy elements for n arrays at once + const auto elems_size_old = res_elems.size(); + res_elems.resize(elems_size_old + chunk_size); + inline_memcpy(&res_elems[elems_size_old], &src_elems[chunk_offset], chunk_size * sizeof(UInt8)); + }; + + const auto * filt_end_aligned = filt_pos + size / FILTER_SIMD_BYTES * FILTER_SIMD_BYTES; + while (filt_pos < filt_end_aligned) + { + auto mask = ToBits64(filt_pos); + if likely (0 != mask) + { + if (const auto prefix_to_copy = prefixToCopy(mask); 0xFF != prefix_to_copy) + { + copy_chunk(offsets_pos, prefix_to_copy); + } + else + { + if (const auto suffix_to_copy = suffixToCopy(mask); 0xFF != suffix_to_copy) + { + copy_chunk(offsets_pos + FILTER_SIMD_BYTES - suffix_to_copy, suffix_to_copy); + } + else + { + while (mask) + { + size_t index = std::countr_zero(mask); + copy_chunk(offsets_pos + index, 1); + mask &= mask - 1; + } + } + } + } + + filt_pos += FILTER_SIMD_BYTES; + offsets_pos += FILTER_SIMD_BYTES; + } + + while (filt_pos < filt_end) + { + if (*filt_pos) + copy_chunk(offsets_pos, 1); + + ++filt_pos; + ++offsets_pos; + } + + return res; +} + +ColumnPtr filterCurrent(ColumnPtr & col, IColumn::Filter & filt, ssize_t result_size_hint) +{ + auto res = ColumnString::create(); + + const auto & src = typeid_cast(*col); + const auto & src_elems = src.getChars(); + const auto & src_offsets = src.getOffsets(); + + auto & res_elems = res->getChars(); + auto & res_offsets = res->getOffsets(); + + filterArraysImpl(src_elems, src_offsets, res_elems, res_offsets, filt, result_size_hint); + + return res; +} + +enum class FilterVersion +{ + AVX2, + Current, +}; + +template +void columnStringFilter(benchmark::State & state, Args &&... 
args) +{ + auto [version, n, set_percent] = std::make_tuple(std::move(args)...); + auto mut_col = ColumnString::create(); + auto & src_chars = mut_col->getChars(); + auto & src_offsets = mut_col->getOffsets(); + src_chars.resize(n); + src_offsets.resize(n); + std::fill(src_offsets.begin(), src_offsets.end(), 'a'); + std::iota(src_offsets.begin(), src_offsets.end(), 0); + auto set_n = n * set_percent; + auto filter = createRandomFilter(n, set_n); + + ColumnPtr col = std::move(mut_col); + + if (version == FilterVersion::AVX2) + { + for (auto _ : state) + { + auto t = filterAVX2(col, filter, set_n); + benchmark::DoNotOptimize(t); + } + } + else + { + for (auto _ : state) + { + auto t = filterCurrent(col, filter, set_n); + benchmark::DoNotOptimize(t); + } + } +} + +BENCHMARK_CAPTURE(columnStringFilter, avx2_00, FilterVersion::AVX2, DEFAULT_BLOCK_SIZE, 0.00); +BENCHMARK_CAPTURE(columnStringFilter, cur_00, FilterVersion::Current, DEFAULT_BLOCK_SIZE, 0.00); +BENCHMARK_CAPTURE(columnStringFilter, avx2_01, FilterVersion::AVX2, DEFAULT_BLOCK_SIZE, 0.01); +BENCHMARK_CAPTURE(columnStringFilter, cur_01, FilterVersion::Current, DEFAULT_BLOCK_SIZE, 0.01); +BENCHMARK_CAPTURE(columnStringFilter, avx2_10, FilterVersion::AVX2, DEFAULT_BLOCK_SIZE, 0.10); +BENCHMARK_CAPTURE(columnStringFilter, cur_10, FilterVersion::Current, DEFAULT_BLOCK_SIZE, 0.10); +BENCHMARK_CAPTURE(columnStringFilter, avx2_20, FilterVersion::AVX2, DEFAULT_BLOCK_SIZE, 0.20); +BENCHMARK_CAPTURE(columnStringFilter, cur_20, FilterVersion::Current, DEFAULT_BLOCK_SIZE, 0.20); +BENCHMARK_CAPTURE(columnStringFilter, avx2_30, FilterVersion::AVX2, DEFAULT_BLOCK_SIZE, 0.30); +BENCHMARK_CAPTURE(columnStringFilter, cur_30, FilterVersion::Current, DEFAULT_BLOCK_SIZE, 0.30); +BENCHMARK_CAPTURE(columnStringFilter, avx2_40, FilterVersion::AVX2, DEFAULT_BLOCK_SIZE, 0.40); +BENCHMARK_CAPTURE(columnStringFilter, cur_40, FilterVersion::Current, DEFAULT_BLOCK_SIZE, 0.40); +BENCHMARK_CAPTURE(columnStringFilter, avx2_50, 
FilterVersion::AVX2, DEFAULT_BLOCK_SIZE, 0.50); +BENCHMARK_CAPTURE(columnStringFilter, cur_50, FilterVersion::Current, DEFAULT_BLOCK_SIZE, 0.50); +BENCHMARK_CAPTURE(columnStringFilter, avx2_60, FilterVersion::AVX2, DEFAULT_BLOCK_SIZE, 0.60); +BENCHMARK_CAPTURE(columnStringFilter, cur_60, FilterVersion::Current, DEFAULT_BLOCK_SIZE, 0.60); +BENCHMARK_CAPTURE(columnStringFilter, avx2_70, FilterVersion::AVX2, DEFAULT_BLOCK_SIZE, 0.70); +BENCHMARK_CAPTURE(columnStringFilter, cur_70, FilterVersion::Current, DEFAULT_BLOCK_SIZE, 0.70); +BENCHMARK_CAPTURE(columnStringFilter, avx2_80, FilterVersion::AVX2, DEFAULT_BLOCK_SIZE, 0.80); +BENCHMARK_CAPTURE(columnStringFilter, cur_80, FilterVersion::Current, DEFAULT_BLOCK_SIZE, 0.80); +BENCHMARK_CAPTURE(columnStringFilter, avx2_90, FilterVersion::AVX2, DEFAULT_BLOCK_SIZE, 0.90); +BENCHMARK_CAPTURE(columnStringFilter, cur_90, FilterVersion::Current, DEFAULT_BLOCK_SIZE, 0.90); +BENCHMARK_CAPTURE(columnStringFilter, avx2_99, FilterVersion::AVX2, DEFAULT_BLOCK_SIZE, 0.99); +BENCHMARK_CAPTURE(columnStringFilter, cur_99, FilterVersion::Current, DEFAULT_BLOCK_SIZE, 0.99); +BENCHMARK_CAPTURE(columnStringFilter, avx2_100, FilterVersion::AVX2, DEFAULT_BLOCK_SIZE, 1.00); +BENCHMARK_CAPTURE(columnStringFilter, cur_100, FilterVersion::Current, DEFAULT_BLOCK_SIZE, 1.00); + +} // namespace bench::String diff --git a/dbms/src/Columns/tests/gtest_column_serialize_deserialize.cpp b/dbms/src/Columns/tests/gtest_column_serialize_deserialize.cpp index 77709a01675..e3009589752 100644 --- a/dbms/src/Columns/tests/gtest_column_serialize_deserialize.cpp +++ b/dbms/src/Columns/tests/gtest_column_serialize_deserialize.cpp @@ -219,7 +219,8 @@ class TestColumnSerializeDeserialize : public ::testing::Test { if (compare_semantics) { - doTestSerializeAndDeserializeForCmp(column_ptr, compare_semantics, collator, sort_key_container); + doTestSerializeAndDeserializeForCmp(column_ptr, true, collator, sort_key_container); + 
doTestSerializeAndDeserializeForCmp(column_ptr, false, collator, sort_key_container); } else { @@ -525,6 +526,17 @@ try testSerializeAndDeserialize(col_decimal_256); testSerializeAndDeserialize(col_decimal_256, true, nullptr, nullptr); + + // Also test row-based interface for ColumnDecimal. + Arena arena; + const char * begin = nullptr; + String sort; + auto new_col_ptr = col_decimal_256->cloneEmpty(); + for (size_t i = 0; i < col_decimal_256->size(); ++i) + col_decimal_256->serializeValueIntoArena(i, arena, begin, nullptr, sort); + for (size_t i = 0; i < col_decimal_256->size(); ++i) + begin = new_col_ptr->deserializeAndInsertFromArena(begin, nullptr); + ASSERT_COLUMN_EQ(std::move(col_decimal_256), std::move(new_col_ptr)); } CATCH diff --git a/dbms/src/Common/Arena.h b/dbms/src/Common/Arena.h index b9999f6b179..a801991ee7b 100644 --- a/dbms/src/Common/Arena.h +++ b/dbms/src/Common/Arena.h @@ -95,14 +95,25 @@ class Arena : private boost::noncopyable } /// Add next contiguous chunk of memory with size not less than specified. - void NO_INLINE addChunk(size_t min_size) + /// If free_empty_head_chunk is true, empty head will be freed. + /// It can avoid memory leak when you always want to reuse one chunk.
+ void NO_INLINE addChunk(size_t min_size, bool free_empty_head_chunk) { if (resize_callback != nullptr) { if unlikely (!resize_callback()) throw ResizeException("Error in arena resize"); } - head = new Chunk(nextSize(min_size), head); + const auto next_size = nextSize(min_size); + if (free_empty_head_chunk && head->remaining() == head->size()) + { + size_in_bytes -= head->size(); + auto * old_head = head; + head = head->prev; + old_head->prev = nullptr; + delete old_head; + } + head = new Chunk(next_size, head); size_in_bytes += head->size(); } @@ -122,7 +133,7 @@ class Arena : private boost::noncopyable ~Arena() { delete head; } /// Get piece of memory with alignment - char * alignedAlloc(size_t size, size_t alignment) + char * alignedAlloc(size_t size, size_t alignment, bool free_empty_head_chunk = false) { do { @@ -137,15 +148,15 @@ class Arena : private boost::noncopyable return res; } - addChunk(size + alignment); + addChunk(size + alignment, free_empty_head_chunk); } while (true); } /// Get piece of memory, without alignment. - char * alloc(size_t size) + char * alloc(size_t size, bool free_empty_head_chunk = false) { if (unlikely(head->pos + size > head->end)) - addChunk(size); + addChunk(size, free_empty_head_chunk); char * res = head->pos; head->pos += size; @@ -156,6 +167,7 @@ class Arena : private boost::noncopyable * Must pass size not more that was just allocated. 
*/ void rollback(size_t size) { head->pos -= size; } + void rollback() { head->pos = head->begin; } void setResizeCallback(const ResizeCallback & resize_callback_) { resize_callback = resize_callback_; } @@ -169,7 +181,7 @@ class Arena : private boost::noncopyable while (unlikely(head->pos + size > head->end)) { char * prev_end = head->pos; - addChunk(size); + addChunk(size, false); if (begin) begin = insert(begin, prev_end - begin); diff --git a/dbms/src/Common/ColumnsHashing.h b/dbms/src/Common/ColumnsHashing.h index a03136bbed8..f06fea800a1 100644 --- a/dbms/src/Common/ColumnsHashing.h +++ b/dbms/src/Common/ColumnsHashing.h @@ -48,8 +48,10 @@ struct HashMethodOneNumber using Self = HashMethodOneNumber; using Base = columns_hashing_impl::HashMethodBase; using KeyHolderType = FieldType; + using BatchKeyHolderType = KeyHolderType; static constexpr bool is_serialized_key = false; + static constexpr bool can_batch_get_key_holder = false; const FieldType * vec; @@ -87,19 +89,108 @@ struct HashMethodOneNumber const FieldType * getKeyData() const { return vec; } }; +class KeyStringBatchHandlerBase +{ +private: + size_t processed_row_idx = 0; + std::vector sort_key_containers{}; + std::vector batch_rows{}; + + template + void prepareNextBatchType( + const UInt8 * chars, + const IColumn::Offsets & offsets, + size_t cur_batch_size, + const TiDB::TiDBCollatorPtr & collator) + { + if (cur_batch_size <= 0) + return; + + batch_rows.resize(cur_batch_size); + + const auto * derived_collator = static_cast(collator); + for (size_t i = 0; i < cur_batch_size; ++i) + { + const auto row = processed_row_idx + i; + const auto last_offset = offsets[row - 1]; + // Remove last zero byte. 
+ StringRef key(chars + last_offset, offsets[row] - last_offset - 1); + if constexpr (has_collator) + key = derived_collator->sortKey(key.data, key.size, sort_key_containers[i]); + + batch_rows[i] = key; + } + processed_row_idx += cur_batch_size; + } + +protected: + bool inited() const { return !sort_key_containers.empty(); } + + void init(size_t start_row, size_t max_batch_size) + { + RUNTIME_CHECK(max_batch_size >= 256); + processed_row_idx = start_row; + sort_key_containers.resize(max_batch_size); + batch_rows.reserve(max_batch_size); + } + + void prepareNextBatch( + const UInt8 * chars, + const IColumn::Offsets & offsets, + size_t cur_batch_size, + const TiDB::TiDBCollatorPtr & collator) + { + if likely (collator && !collator->isTrivialCollator()) + { +#define M(VAR_PREFIX, COLLATOR_NAME, IMPL_TYPE, COLLATOR_ID) \ + case (COLLATOR_ID): \ + { \ + return prepareNextBatchType(chars, offsets, cur_batch_size, collator); \ + } + + switch (collator->getCollatorId()) + { + APPLY_FOR_COLLATOR_TYPES(M) + default: + { + throw Exception(fmt::format("unexpected collator: {}", collator->getCollatorId())); + } + }; +#undef M + } + else + { + return prepareNextBatchType(chars, offsets, cur_batch_size, collator); + } + } + +public: + // NOTE: i is the index of mini batch, it's not the row index of Column. + ALWAYS_INLINE inline ArenaKeyHolder getKeyHolderBatch(size_t i, Arena * pool) const + { + assert(inited()); + assert(i < batch_rows.size()); + return ArenaKeyHolder{batch_rows[i], pool}; + } +}; /// For the case when there is one string key. 
template struct HashMethodString : public columns_hashing_impl::HashMethodBase, Value, Mapped, use_cache> + , KeyStringBatchHandlerBase { using Self = HashMethodString; using Base = columns_hashing_impl::HashMethodBase; using KeyHolderType = ArenaKeyHolder; + using BatchKeyHolderType = KeyHolderType; + + using BatchHandlerBase = KeyStringBatchHandlerBase; static constexpr bool is_serialized_key = false; + static constexpr bool can_batch_get_key_holder = true; - const IColumn::Offset * offsets; + const IColumn::Offsets * offsets; const UInt8 * chars; TiDB::TiDBCollatorPtr collator = nullptr; @@ -110,20 +201,34 @@ struct HashMethodString { const IColumn & column = *key_columns[0]; const auto & column_string = assert_cast(column); - offsets = column_string.getOffsets().data(); + offsets = &column_string.getOffsets(); chars = column_string.getChars().data(); if (!collators.empty()) collator = collators[0]; } + void initBatchHandler(size_t start_row, size_t max_batch_size) + { + assert(!BatchHandlerBase::inited()); + BatchHandlerBase::init(start_row, max_batch_size); + } + + void prepareNextBatch(Arena *, size_t cur_batch_size) + { + assert(BatchHandlerBase::inited()); + BatchHandlerBase::prepareNextBatch(chars, *offsets, cur_batch_size, collator); + } + ALWAYS_INLINE inline KeyHolderType getKeyHolder( ssize_t row, [[maybe_unused]] Arena * pool, std::vector & sort_key_containers) const { - auto last_offset = row == 0 ? 0 : offsets[row - 1]; + assert(!BatchHandlerBase::inited()); + + auto last_offset = (*offsets)[row - 1]; // Remove last zero byte. 
- StringRef key(chars + last_offset, offsets[row] - last_offset - 1); + StringRef key(chars + last_offset, (*offsets)[row] - last_offset - 1); if (likely(collator)) key = collator->sortKey(key.data, key.size, sort_key_containers[0]); @@ -141,8 +246,10 @@ struct HashMethodStringBin using Self = HashMethodStringBin; using Base = columns_hashing_impl::HashMethodBase; using KeyHolderType = ArenaKeyHolder; + using BatchKeyHolderType = KeyHolderType; static constexpr bool is_serialized_key = false; + static constexpr bool can_batch_get_key_holder = false; const IColumn::Offset * offsets; const UInt8 * chars; @@ -167,208 +274,6 @@ struct HashMethodStringBin friend class columns_hashing_impl::HashMethodBase; }; -/* -/// For the case when there is multi string key. -template -struct HashMethodMultiString - : public columns_hashing_impl::HashMethodBase, Value, Mapped, false> -{ - using Self = HashMethodMultiString; - using Base = columns_hashing_impl::HashMethodBase; - - std::vector offsets; - std::vector chars; - TiDB::TiDBCollators collators; - bool all_collators_padding_bin = false; - - HashMethodMultiString(const ColumnRawPtrs & key_columns, const Sizes &, const TiDB::TiDBCollators & collators_) - : collators(collators_) - { - size_t num = key_columns.size(); - offsets.resize(num); - chars.resize(num); - - for (size_t i = 0; i < num; ++i) - { - const IColumn & column = *key_columns[i]; - const auto & column_string = assert_cast(column); - offsets[i] = column_string.getOffsets().data(); - chars[i] = column_string.getChars().data(); - } - if (!collators.empty()) - { - all_collators_padding_bin = std::all_of(collators.begin(), collators.end(), [](auto & x) { - return x->isPaddingBinary(); - }); - } - } - - template - ALWAYS_INLINE inline SerializedKeyHolder genSerializedKeyHolder(ssize_t row, Arena * pool, F && fn_handle_key) const - { - auto num = offsets.size(); - - const char * begin = nullptr; - size_t sum_size = 0; - - for (size_t key_index = 0; key_index < num; 
++key_index) - { - auto last_offset = row == 0 ? 0 : offsets[key_index][row - 1]; - StringRef key(chars[key_index] + last_offset, offsets[key_index][row] - last_offset - 1); - - key = fn_handle_key(key_index, key); - - char * pos = pool->allocContinue(key.size + sizeof(key.size), begin); - { - memcpy(pos, &key.size, sizeof(key.size)); - inline_memcpy(pos + sizeof(key.size), key.data, key.size); - } - - sum_size += key.size + sizeof(key.size); - } - return SerializedKeyHolder{{begin, sum_size}, *pool}; - } - - ALWAYS_INLINE inline auto getKeyHolder(ssize_t row, Arena * pool, std::vector & sort_key_containers) const - { - if (likely(all_collators_padding_bin)) - { - return genSerializedKeyHolder(row, pool, [](size_t, StringRef key) { - return DB::BinCollatorSortKey(key.data, key.size); - }); - } - - if (unlikely(collators.empty())) - { - return genSerializedKeyHolder(row, pool, [](size_t, StringRef key) { - return key; - }); - } - else - { - return genSerializedKeyHolder(row, pool, [&](size_t key_index, StringRef key) { - if (collators[key_index]) - return collators[key_index]->sortKey(key.data, key.size, sort_key_containers[key_index]); - return key; - }); - } - } - -protected: - friend class columns_hashing_impl::HashMethodBase; -}; -*/ - -static_assert(std::is_same_v(0)->size)>); - -struct KeyDescNumber64 -{ - using ColumnType = ColumnUInt64; - using AllocSize = size_t; - static constexpr size_t ElementSize = sizeof(ColumnType::value_type); - - explicit KeyDescNumber64(const IColumn * key_column_) { column = static_cast(key_column_); } - static inline void serializeKey(char *& pos, const StringRef & ref) - { - std::memcpy(pos, ref.data, ElementSize); - pos += ElementSize; - } - ALWAYS_INLINE inline AllocSize getKey(ssize_t row, StringRef & ref) const - { - const auto & element = column->getElement(row); - ref = {reinterpret_cast(&element), ElementSize}; - return ElementSize; - } - const ColumnType * column{}; -}; - -struct KeyDescStringBin -{ - using ColumnType = 
ColumnString; - using AllocSize = size_t; - - explicit KeyDescStringBin(const IColumn * key_column_) { column = static_cast(key_column_); } - static inline void serializeKey(char *& pos, const StringRef & ref) - { - std::memcpy(pos, &ref.size, sizeof(ref.size)); - pos += sizeof(ref.size); - inline_memcpy(pos, ref.data, ref.size); - pos += ref.size; - } - - template - ALWAYS_INLINE inline AllocSize getKeyImpl(ssize_t row, StringRef & key, F && fn_handle_key) const - { - const auto * offsets = column->getOffsets().data(); - const auto * chars = column->getChars().data(); - - size_t last_offset = 0; - if (likely(row != 0)) - last_offset = offsets[row - 1]; - - key = {chars + last_offset, offsets[row] - last_offset - 1}; - key = fn_handle_key(key); - - return key.size + sizeof(key.size); - } - - ALWAYS_INLINE inline AllocSize getKey(ssize_t row, StringRef & ref) const - { - return getKeyImpl(row, ref, [](StringRef key) { return key; }); - } - - const ColumnType * column{}; -}; - -struct KeyDescStringBinPadding : KeyDescStringBin -{ - explicit KeyDescStringBinPadding(const IColumn * key_column_) - : KeyDescStringBin(key_column_) - {} - - ALWAYS_INLINE inline AllocSize getKey(ssize_t row, StringRef & ref) const - { - return getKeyImpl(row, ref, [](StringRef key) { return DB::BinCollatorSortKey(key.data, key.size); }); - } -}; - -/// For the case when there are 2 keys. 
-template -struct HashMethodFastPathTwoKeysSerialized - : public columns_hashing_impl:: - HashMethodBase, Value, Mapped, false> -{ - using Self = HashMethodFastPathTwoKeysSerialized; - using Base = columns_hashing_impl::HashMethodBase; - using KeyHolderType = SerializedKeyHolder; - - static constexpr bool is_serialized_key = true; - - Key1Desc key_1_desc; - Key2Desc key_2_desc; - - HashMethodFastPathTwoKeysSerialized(const ColumnRawPtrs & key_columns, const Sizes &, const TiDB::TiDBCollators &) - : key_1_desc(key_columns[0]) - , key_2_desc(key_columns[1]) - {} - - ALWAYS_INLINE inline KeyHolderType getKeyHolder(ssize_t row, Arena * pool, std::vector &) const - { - StringRef key1; - StringRef key2; - size_t alloc_size = key_1_desc.getKey(row, key1) + key_2_desc.getKey(row, key2); - char * start = pool->alloc(alloc_size); - SerializedKeyHolder ret{{start, alloc_size}, pool}; - Key1Desc::serializeKey(start, key1); - Key2Desc::serializeKey(start, key2); - return ret; - } - -protected: - friend class columns_hashing_impl::HashMethodBase; -}; - - /// For the case when there is one fixed-length string key. 
template struct HashMethodFixedString @@ -378,8 +283,10 @@ struct HashMethodFixedString using Self = HashMethodFixedString; using Base = columns_hashing_impl::HashMethodBase; using KeyHolderType = ArenaKeyHolder; + using BatchKeyHolderType = KeyHolderType; static constexpr bool is_serialized_key = false; + static constexpr bool can_batch_get_key_holder = false; size_t n; const ColumnFixedString::Chars_t * chars; @@ -426,8 +333,10 @@ struct HashMethodKeysFixed using BaseHashed = columns_hashing_impl::HashMethodBase; using Base = columns_hashing_impl::BaseStateKeysFixed; using KeyHolderType = Key; + using BatchKeyHolderType = KeyHolderType; static constexpr bool is_serialized_key = false; + static constexpr bool can_batch_get_key_holder = false; static constexpr bool has_nullable_keys = has_nullable_keys_; Sizes key_sizes; @@ -570,6 +479,96 @@ struct HashMethodKeysFixed } }; +class KeySerializedBatchHandlerBase +{ +private: + size_t processed_row_idx = 0; + String sort_key_container{}; + PaddedPODArray byte_size{}; + PaddedPODArray pos{}; + PaddedPODArray ori_pos{}; + PaddedPODArray real_byte_size{}; + + ALWAYS_INLINE inline void santityCheck() const + { + assert(ori_pos.size() == pos.size() && real_byte_size.size() == pos.size()); + } + + ALWAYS_INLINE inline void resize(size_t batch_size) + { + pos.resize(batch_size); + ori_pos.resize(batch_size); + real_byte_size.resize(batch_size); + } + +protected: + bool inited() const { return !byte_size.empty(); } + + void init(size_t start_row, const ColumnRawPtrs & key_columns, const TiDB::TiDBCollators & collators) + { + // When start_row is not 0, byte_size will be re-initialized for the same block. + // However, the situation where start_row != 0 will only occur when spilling happens, + // so there is no need to consider the performance impact of repeatedly calling countSerializeByteSizeForCmp. 
+ processed_row_idx = start_row; + byte_size.resize_fill_zero(key_columns[0]->size()); + RUNTIME_CHECK(!byte_size.empty()); + for (size_t i = 0; i < key_columns.size(); ++i) + key_columns[i]->countSerializeByteSizeForCmp( + byte_size, + nullptr, + collators.empty() ? nullptr : collators[i]); + } + + void prepareNextBatch( + const ColumnRawPtrs & key_columns, + Arena * pool, + size_t cur_batch_size, + const TiDB::TiDBCollators & collators) + { + santityCheck(); + resize(cur_batch_size); + + if unlikely (cur_batch_size <= 0) + return; + + assert(processed_row_idx + cur_batch_size <= byte_size.size()); + size_t mem_size = 0; + for (size_t i = processed_row_idx; i < processed_row_idx + cur_batch_size; ++i) + mem_size += byte_size[i]; + + auto * ptr = static_cast(pool->alignedAlloc(mem_size, 16, /*free_empty_head_chunk=*/true)); + for (size_t i = 0; i < cur_batch_size; ++i) + { + pos[i] = ptr; + ori_pos[i] = ptr; + ptr += byte_size[i + processed_row_idx]; + } + + for (size_t i = 0; i < key_columns.size(); ++i) + key_columns[i]->serializeToPosForCmp( + pos, + processed_row_idx, + cur_batch_size, + false, + nullptr, + collators.empty() ? nullptr : collators[i], + &sort_key_container); + + for (size_t i = 0; i < cur_batch_size; ++i) + real_byte_size[i] = pos[i] - ori_pos[i]; + + processed_row_idx += cur_batch_size; + } + +public: + // NOTE: i is the index of mini batch, it's not the row index of Column. + ALWAYS_INLINE inline ArenaKeyHolder getKeyHolderBatch(size_t i, Arena * pool) const + { + santityCheck(); + assert(i < ori_pos.size()); + return ArenaKeyHolder{StringRef{ori_pos[i], real_byte_size[i]}, pool}; + } +}; /** Hash by concatenating serialized key values. * The serialized value differs in that it uniquely allows to deserialize it, having only the position with which it starts. 
@@ -579,12 +578,16 @@ struct HashMethodKeysFixed template struct HashMethodSerialized : public columns_hashing_impl::HashMethodBase, Value, Mapped, false> + , KeySerializedBatchHandlerBase { using Self = HashMethodSerialized; using Base = columns_hashing_impl::HashMethodBase; + using BatchHandlerBase = KeySerializedBatchHandlerBase; using KeyHolderType = SerializedKeyHolder; + using BatchKeyHolderType = ArenaKeyHolder; static constexpr bool is_serialized_key = true; + static constexpr bool can_batch_get_key_holder = true; ColumnRawPtrs key_columns; size_t keys_size; @@ -599,9 +602,22 @@ struct HashMethodSerialized , collators(collators_) {} + void initBatchHandler(size_t start_row, size_t /* max_batch_size */) + { + assert(!BatchHandlerBase::inited()); + BatchHandlerBase::init(start_row, key_columns, collators); + } + + void prepareNextBatch(Arena * pool, size_t cur_batch_size) + { + assert(BatchHandlerBase::inited()); + BatchHandlerBase::prepareNextBatch(key_columns, pool, cur_batch_size, collators); + } + ALWAYS_INLINE inline KeyHolderType getKeyHolder(size_t row, Arena * pool, std::vector & sort_key_containers) const { + assert(!BatchHandlerBase::inited()); return SerializedKeyHolder{ serializeKeysToPoolContiguous(row, keys_size, key_columns, collators, sort_key_containers, *pool), pool}; @@ -620,8 +636,10 @@ struct HashMethodHashed using Self = HashMethodHashed; using Base = columns_hashing_impl::HashMethodBase; using KeyHolderType = Key; + using BatchKeyHolderType = KeyHolderType; static constexpr bool is_serialized_key = false; + static constexpr bool can_batch_get_key_holder = false; ColumnRawPtrs key_columns; TiDB::TiDBCollators collators; diff --git a/dbms/src/Common/ColumnsHashingImpl.h b/dbms/src/Common/ColumnsHashingImpl.h index bf130d2bd29..4f6c56be6d4 100644 --- a/dbms/src/Common/ColumnsHashingImpl.h +++ b/dbms/src/Common/ColumnsHashingImpl.h @@ -140,6 +140,12 @@ class HashMethodBase return emplaceImpl(key_holder, data); } + template + ALWAYS_INLINE 
inline EmplaceResult emplaceKey(Data & data, KeyHolder && key_holder) + { + return emplaceImpl(key_holder, data); + } + template ALWAYS_INLINE inline EmplaceResult emplaceKey(Data & data, KeyHolder && key_holder, size_t hashval) { @@ -157,6 +163,12 @@ class HashMethodBase return findKeyImpl(keyHolderGetKey(key_holder), data); } + template + ALWAYS_INLINE inline FindResult findKey(Data & data, KeyHolder && key_holder) + { + return findKeyImpl(keyHolderGetKey(key_holder), data); + } + template ALWAYS_INLINE inline FindResult findKey(Data & data, KeyHolder && key_holder, size_t hashval) { diff --git a/dbms/src/Common/ExternalTable.h b/dbms/src/Common/ExternalTable.h deleted file mode 100644 index a62c9186b1e..00000000000 --- a/dbms/src/Common/ExternalTable.h +++ /dev/null @@ -1,250 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - - -namespace DB -{ -namespace ErrorCodes -{ -extern const int BAD_ARGUMENTS; -} - - -/// The base class containing the basic information about external table and -/// basic functions for extracting this information from text fields. 
-class BaseExternalTable -{ -public: - std::string file; /// File with data or '-' if stdin - std::string name; /// The name of the table - std::string format; /// Name of the data storage format - - /// Description of the table structure: (column name, data type name) - std::vector> structure; - - std::unique_ptr read_buffer; - Block sample_block; - - virtual ~BaseExternalTable() = default; - - /// Initialize read_buffer, depending on the data source. By default, does nothing. - virtual void initReadBuffer(){}; - - /// Get the table data - a pair (a thread with the contents of the table, the name of the table) - ExternalTableData getData(const Context & context) - { - initReadBuffer(); - initSampleBlock(); - ExternalTableData res = std::make_pair( - std::make_shared( - context.getInputFormat(format, *read_buffer, sample_block, DEFAULT_BLOCK_SIZE)), - name); - return res; - } - -protected: - /// Clear all accumulated information - void clean() - { - name = ""; - file = ""; - format = ""; - structure.clear(); - sample_block = Block(); - read_buffer.reset(); - } - - /// Function for debugging information output - void write() - { - std::cerr << "file " << file << std::endl; - std::cerr << "name " << name << std::endl; - std::cerr << "format " << format << std::endl; - std::cerr << "structure: \n"; - for (const auto & col_dt : structure) - std::cerr << "\t" << col_dt.first << " " << col_dt.second << std::endl; - } - - static std::vector split(const std::string & s, const std::string & d) - { - std::vector res; - boost::split(res, s, boost::algorithm::is_any_of(d), boost::algorithm::token_compress_on); - return res; - } - - /// Construct the `structure` vector from the text field `structure` - virtual void parseStructureFromStructureField(const std::string & argument) - { - std::vector vals = split(argument, " ,"); - - if (vals.size() & 1) - throw Exception("Odd number of attributes in section structure", ErrorCodes::BAD_ARGUMENTS); - - for (size_t i = 0; i < 
vals.size(); i += 2) - structure.emplace_back(vals[i], vals[i + 1]); - } - - /// Construct the `structure` vector from the text field `types` - virtual void parseStructureFromTypesField(const std::string & argument) - { - std::vector vals = split(argument, " ,"); - - for (size_t i = 0; i < vals.size(); ++i) - structure.emplace_back("_" + toString(i + 1), vals[i]); - } - -private: - /// Initialize sample_block according to the structure of the table stored in the `structure` - void initSampleBlock() - { - const DataTypeFactory & data_type_factory = DataTypeFactory::instance(); - - for (const auto & col_dt : structure) - { - ColumnWithTypeAndName column; - column.name = col_dt.first; - column.type = data_type_factory.get(col_dt.second); - column.column = column.type->createColumn(); - sample_block.insert(std::move(column)); - } - } -}; - - -/// Parsing of external table used in the tcp client. -class ExternalTable : public BaseExternalTable -{ -public: - void initReadBuffer() override - { - if (file == "-") - read_buffer = std::make_unique(STDIN_FILENO); - else - read_buffer = std::make_unique(file); - } - - /// Extract parameters from variables_map, which is built on the client command line - explicit ExternalTable(const boost::program_options::variables_map & external_options) - { - if (external_options.count("file")) - file = external_options["file"].as(); - else - throw Exception("--file field have not been provided for external table", ErrorCodes::BAD_ARGUMENTS); - - if (external_options.count("name")) - name = external_options["name"].as(); - else - throw Exception("--name field have not been provided for external table", ErrorCodes::BAD_ARGUMENTS); - - if (external_options.count("format")) - format = external_options["format"].as(); - else - throw Exception("--format field have not been provided for external table", ErrorCodes::BAD_ARGUMENTS); - - if (external_options.count("structure")) - parseStructureFromStructureField(external_options["structure"].as()); - 
else if (external_options.count("types")) - parseStructureFromTypesField(external_options["types"].as()); - else - throw Exception( - "Neither --structure nor --types have not been provided for external table", - ErrorCodes::BAD_ARGUMENTS); - } -}; - -/// Parsing of external table used when sending tables via http -/// The `handlePart` function will be called for each table passed, -/// so it's also necessary to call `clean` at the end of the `handlePart`. -class ExternalTablesHandler - : public Poco::Net::PartHandler - , BaseExternalTable -{ -public: - std::vector names; - - ExternalTablesHandler(Context & context_, Poco::Net::NameValueCollection params_) - : context(context_) - , params(params_) - {} - - void handlePart(const Poco::Net::MessageHeader & header, std::istream & stream) override - { - /// The buffer is initialized here, not in the virtual function initReadBuffer - read_buffer = std::make_unique(stream); - - /// Retrieve a collection of parameters from MessageHeader - Poco::Net::NameValueCollection content; - std::string label; - Poco::Net::MessageHeader::splitParameters(header.get("Content-Disposition"), label, content); - - /// Get parameters - name = content.get("name", "_data"); - format = params.get(name + "_format", "TabSeparated"); - - if (params.has(name + "_structure")) - parseStructureFromStructureField(params.get(name + "_structure")); - else if (params.has(name + "_types")) - parseStructureFromTypesField(params.get(name + "_types")); - else - throw Exception( - "Neither structure nor types have not been provided for external table " + name + ". 
Use fields " + name - + "_structure or " + name + "_types to do so.", - ErrorCodes::BAD_ARGUMENTS); - - ExternalTableData data = getData(context); - - /// Create table - NamesAndTypesList columns = sample_block.getNamesAndTypesList(); - StoragePtr storage = StorageMemory::create(data.second, ColumnsDescription{columns}); - storage->startup(); - context.addExternalTable(data.second, storage); - BlockOutputStreamPtr output = storage->write(ASTPtr(), context.getSettingsRef()); - - /// Write data - data.first->readPrefix(); - output->writePrefix(); - while (Block block = data.first->read()) - output->write(block); - data.first->readSuffix(); - output->writeSuffix(); - - names.push_back(name); - /// We are ready to receive the next file, for this we clear all the information received - clean(); - } - -private: - Context & context; - Poco::Net::NameValueCollection params; -}; - - -} // namespace DB diff --git a/dbms/src/Common/FailPoint.cpp b/dbms/src/Common/FailPoint.cpp index 30e040bafda..3c7dbf06f63 100644 --- a/dbms/src/Common/FailPoint.cpp +++ b/dbms/src/Common/FailPoint.cpp @@ -118,6 +118,7 @@ namespace DB M(force_agg_on_partial_block) \ M(force_agg_prefetch) \ M(force_magic_hash) \ + M(disable_agg_batch_get_key_holder) \ M(force_set_fap_candidate_store_id) \ M(force_not_clean_fap_on_destroy) \ M(force_fap_worker_throw) \ @@ -127,7 +128,9 @@ namespace DB M(force_thread_0_no_agg_spill) \ M(force_checkpoint_dump_throw_datafile) \ M(force_semi_join_time_exceed) \ - M(force_set_proxy_state_machine_cpu_cores) + M(force_set_proxy_state_machine_cpu_cores) \ + M(force_join_v2_probe_enable_lm) \ + M(force_join_v2_probe_disable_lm) #define APPLY_FOR_PAUSEABLE_FAILPOINTS_ONCE(M) \ M(pause_with_alter_locks_acquired) \ @@ -419,6 +422,8 @@ void FailPointHelper::wait(const String &) {} void FailPointHelper::initRandomFailPoints(Poco::Util::LayeredConfiguration &, const LoggerPtr &) {} void FailPointHelper::enableRandomFailPoint(const String &, double) {} + +void 
FailPointHelper::disableRandomFailPoints(Poco::Util::LayeredConfiguration &, const LoggerPtr &) {} #endif } // namespace DB diff --git a/dbms/src/Common/TiFlashMetrics.h b/dbms/src/Common/TiFlashMetrics.h index 12d194eb2e3..428e6b0c3b8 100644 --- a/dbms/src/Common/TiFlashMetrics.h +++ b/dbms/src/Common/TiFlashMetrics.h @@ -860,6 +860,7 @@ static_assert(RAFT_REGION_BIG_WRITE_THRES * 4 < RAFT_REGION_BIG_WRITE_MAX, "Inva "Vector index memory usage", \ Gauge, \ F(type_build, {"type", "build"}), \ + F(type_load, {"type", "load"}), \ F(type_view, {"type", "view"})) \ M(tiflash_vector_index_build_count, \ "Vector index build count", \ @@ -870,11 +871,26 @@ static_assert(RAFT_REGION_BIG_WRITE_THRES * 4 < RAFT_REGION_BIG_WRITE_MAX, "Inva "Active Vector index instances", \ Gauge, \ F(type_build, {"type", "build"}), \ + F(type_load, {"type", "load"}), \ F(type_view, {"type", "view"})) \ M(tiflash_vector_index_duration, \ "Vector index operation duration", \ Histogram, \ F(type_build, {{"type", "build"}}, ExpBuckets{0.001, 2, 20}), \ + F(type_load_cf, {{"type", "load_cf"}}, ExpBuckets{0.001, 2, 20}), \ + F(type_load_cache, {{"type", "load_cache"}}, ExpBuckets{0.001, 2, 20}), \ + F(type_load_dmfile_local, {{"type", "load_dmfile_local"}}, ExpBuckets{0.001, 2, 20}), \ + F(type_load_dmfile_s3, {{"type", "load_dmfile_s3"}}, ExpBuckets{0.001, 2, 20}), \ + F(type_search, {{"type", "search"}}, ExpBuckets{0.001, 2, 20})) \ + M(tiflash_inverted_index_active_instances, \ + "Active Inverted index instances", \ + Gauge, \ + F(type_build, {"type", "build"}), \ + F(type_view, {"type", "view"})) \ + M(tiflash_inverted_index_duration, \ + "Inverted index operation duration", \ + Histogram, \ + F(type_build, {{"type", "build"}}, ExpBuckets{0.001, 2, 20}), \ F(type_download, {{"type", "download"}}, ExpBuckets{0.001, 2, 20}), \ F(type_view, {{"type", "view"}}, ExpBuckets{0.001, 2, 20}), \ F(type_search, {{"type", "search"}}, ExpBuckets{0.001, 2, 20})) \ diff --git a/dbms/src/Common/config.h.in 
b/dbms/src/Common/config.h.in index 9544c71b768..a69043d130f 100644 --- a/dbms/src/Common/config.h.in +++ b/dbms/src/Common/config.h.in @@ -3,7 +3,6 @@ // .h autogenerated by cmake! #cmakedefine01 USE_RE2_ST -#cmakedefine01 USE_VECTORCLASS #cmakedefine01 Poco_NetSSL_FOUND #cmakedefine01 USE_GM_SSL #cmakedefine01 USE_QPL diff --git a/dbms/src/Common/config_build.cpp.in b/dbms/src/Common/config_build.cpp.in index bd1474af693..2faabe7c68f 100644 --- a/dbms/src/Common/config_build.cpp.in +++ b/dbms/src/Common/config_build.cpp.in @@ -57,8 +57,6 @@ const char * auto_config_build[]{ "@USE_QPL@", "USE_RE2_ST", "@USE_RE2_ST@", - "USE_VECTORCLASS", - "@USE_VECTORCLASS@", "USE_Poco_NetSSL", "@Poco_NetSSL_FOUND@", diff --git a/dbms/src/Core/Block.cpp b/dbms/src/Core/Block.cpp index b04e25275af..71b39c7ba6c 100644 --- a/dbms/src/Core/Block.cpp +++ b/dbms/src/Core/Block.cpp @@ -732,6 +732,8 @@ void Block::swap(Block & other) noexcept std::swap(info, other.info); data.swap(other.data); index_by_name.swap(other.index_by_name); + std::swap(start_offset, other.start_offset); + std::swap(segment_row_id_col, other.segment_row_id_col); std::swap(rs_result, other.rs_result); } diff --git a/dbms/src/DataStreams/AddingDefaultBlockOutputStream.cpp b/dbms/src/DataStreams/AddingDefaultBlockOutputStream.cpp index 7cd3c071564..2039d852589 100644 --- a/dbms/src/DataStreams/AddingDefaultBlockOutputStream.cpp +++ b/dbms/src/DataStreams/AddingDefaultBlockOutputStream.cpp @@ -18,10 +18,33 @@ #include #include #include +#include namespace DB { +AddingDefaultBlockOutputStream::AddingDefaultBlockOutputStream( + const StoragePtr & storage_, + const ASTPtr & query_ptr_, + const Block & header_, + NamesAndTypesList required_columns_, + const ColumnDefaults & column_defaults_, + const Context & context_) + : storage(storage_) + , header(header_) + , required_columns(required_columns_) + , column_defaults(column_defaults_) + , context(context_) + , query_ptr(query_ptr_) +{ + /** Notice + * This is a very 
important line. At any insertion into the table one of streams should own lock. + * Although now any insertion into the table is done via AddingDefaultBlockOutputStream, + * but it's clear that here is not the best place for this functionality. + */ + addTableLock(storage->lockForShare(context.getCurrentQueryId())); + output = storage->write(query_ptr, context.getSettingsRef()); +} void AddingDefaultBlockOutputStream::write(const Block & block) { @@ -38,7 +61,7 @@ void AddingDefaultBlockOutputStream::write(const Block & block) { const auto & elem = res.getByPosition(i); - if (const ColumnArray * array = typeid_cast(&*elem.column)) + if (const auto * array = typeid_cast(&*elem.column)) { String offsets_name = Nested::extractTableName(elem.name); auto & offsets_column = offset_columns[offsets_name]; diff --git a/dbms/src/DataStreams/AddingDefaultBlockOutputStream.h b/dbms/src/DataStreams/AddingDefaultBlockOutputStream.h index d1cb18487ce..d8d1e1ef50f 100644 --- a/dbms/src/DataStreams/AddingDefaultBlockOutputStream.h +++ b/dbms/src/DataStreams/AddingDefaultBlockOutputStream.h @@ -19,6 +19,7 @@ #include #include #include +#include namespace DB @@ -32,17 +33,12 @@ class AddingDefaultBlockOutputStream : public IBlockOutputStream { public: AddingDefaultBlockOutputStream( - const BlockOutputStreamPtr & output_, + const StoragePtr & storage_, + const ASTPtr & query_ptr_, const Block & header_, NamesAndTypesList required_columns_, const ColumnDefaults & column_defaults_, - const Context & context_) - : output(output_) - , header(header_) - , required_columns(required_columns_) - , column_defaults(column_defaults_) - , context(context_) - {} + const Context & context_); Block getHeader() const override { return header; } void write(const Block & block) override; @@ -53,11 +49,13 @@ class AddingDefaultBlockOutputStream : public IBlockOutputStream void writeSuffix() override; private: + StoragePtr storage; BlockOutputStreamPtr output; Block header; NamesAndTypesList 
required_columns; const ColumnDefaults column_defaults; const Context & context; + ASTPtr query_ptr; }; diff --git a/dbms/src/DataStreams/FilterTransformAction.cpp b/dbms/src/DataStreams/FilterTransformAction.cpp index 8a366d51ec4..f6860318b9b 100644 --- a/dbms/src/DataStreams/FilterTransformAction.cpp +++ b/dbms/src/DataStreams/FilterTransformAction.cpp @@ -102,16 +102,18 @@ bool FilterTransformAction::transform(Block & block, FilterPtr & res_filter, boo * and now - are calculated. That is, not all cases are covered by the code above. * This happens if the function returns a constant for a non-constant argument. * For example, `ignore` function. + * use a local variable to avoid the case that function return a constant for a non-constant argument in one block + * but return a non-constant for the same argument in another block. */ - constant_filter_description = ConstantFilterDescription(*column_of_filter); + auto current_constant_filter_description = ConstantFilterDescription(*column_of_filter); - if (constant_filter_description.always_false) + if (current_constant_filter_description.always_false) { block.clear(); return true; } - if (constant_filter_description.always_true) + if (current_constant_filter_description.always_true) { if (return_filter) res_filter = nullptr; @@ -140,7 +142,7 @@ bool FilterTransformAction::transform(Block & block, FilterPtr & res_filter, boo if (filtered_rows == rows) { /// Replace the column with the filter by a constant. - auto filter_column = block.safeGetByPosition(filter_column_position); + auto & filter_column = block.safeGetByPosition(filter_column_position); filter_column.column = filter_column.type->createColumnConst(filtered_rows, static_cast(1)); /// No need to touch the rest of the columns. 
return true; diff --git a/dbms/src/DataStreams/MergeSortingBlockInputStream.cpp b/dbms/src/DataStreams/MergeSortingBlockInputStream.cpp index ca8418faf64..cbf788057a4 100644 --- a/dbms/src/DataStreams/MergeSortingBlockInputStream.cpp +++ b/dbms/src/DataStreams/MergeSortingBlockInputStream.cpp @@ -88,6 +88,11 @@ Block MergeSortingBlockInputStream::readImpl() return block; SortHelper::removeConstantsFromBlock(block); + RUNTIME_CHECK_MSG( + block.columns() == header_without_constants.columns(), + "Unexpected number of constant columns in block in MergeSortingBlockInputStream, n_block={}, n_head={}", + block.columns(), + header_without_constants.columns()); blocks.push_back(block); sum_bytes_in_blocks += block.estimateBytesForSpill(); diff --git a/dbms/src/DataStreams/PushingToViewsBlockOutputStream.cpp b/dbms/src/DataStreams/PushingToViewsBlockOutputStream.cpp deleted file mode 100644 index 513753102be..00000000000 --- a/dbms/src/DataStreams/PushingToViewsBlockOutputStream.cpp +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include -#include -#include -#include -#include - - -namespace DB -{ -PushingToViewsBlockOutputStream::PushingToViewsBlockOutputStream( - const String & database, - const String & table, - const StoragePtr & storage, - const Context & context_, - const ASTPtr & query_ptr_, - bool no_destination) - : context(context_) - , query_ptr(query_ptr_) -{ - /** TODO This is a very important line. At any insertion into the table one of streams should own lock. - * Although now any insertion into the table is done via PushingToViewsBlockOutputStream, - * but it's clear that here is not the best place for this functionality. - */ - addTableLock(storage->lockForShare(context.getCurrentQueryId())); - - if (!table.empty()) - { - Dependencies dependencies = context.getDependencies(database, table); - - RUNTIME_CHECK_MSG(dependencies.empty(), "Do not support ClickHouse's materialized view"); - } - - /* Do not push to destination table if the flag is set */ - if (!no_destination) - { - output = storage->write(query_ptr, context.getSettingsRef()); - } -} - - -void PushingToViewsBlockOutputStream::write(const Block & block) -{ - if (output) - output->write(block); -} - -} // namespace DB diff --git a/dbms/src/DataStreams/PushingToViewsBlockOutputStream.h b/dbms/src/DataStreams/PushingToViewsBlockOutputStream.h deleted file mode 100644 index 576a63e5262..00000000000 --- a/dbms/src/DataStreams/PushingToViewsBlockOutputStream.h +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include -#include -#include - - -namespace DB -{ - - -/** Writes data to the specified table and to all dependent materialized views. - */ -class PushingToViewsBlockOutputStream : public IBlockOutputStream -{ -public: - PushingToViewsBlockOutputStream( - const String & database, - const String & table, - const StoragePtr & storage, - const Context & context_, - const ASTPtr & query_ptr_, - bool no_destination = false); - - Block getHeader() const override { return storage->getSampleBlock(); } - void write(const Block & block) override; - - void flush() override - { - if (output) - output->flush(); - } - - void writePrefix() override - { - if (output) - output->writePrefix(); - } - - void writeSuffix() override - { - if (output) - output->writeSuffix(); - } - -private: - StoragePtr storage; - BlockOutputStreamPtr output; - - const Context & context; - ASTPtr query_ptr; -}; - - -} // namespace DB diff --git a/dbms/src/DataStreams/WindowTransformAction.cpp b/dbms/src/DataStreams/WindowTransformAction.cpp index f981bf017f9..9b145012de9 100644 --- a/dbms/src/DataStreams/WindowTransformAction.cpp +++ b/dbms/src/DataStreams/WindowTransformAction.cpp @@ -312,8 +312,8 @@ void WindowTransformAction::initialAggregateFunction( workspace.argument_columns.assign(workspace.arguments.size(), nullptr); workspace.aggregate_function = window_function_description.aggregate_function; const auto & aggregate_function = workspace.aggregate_function; - if (aggregate_function->allocatesMemoryInArena()) - throw Exception("arena is not supported now"); + if (!arena) + arena = std::make_unique(); workspace.aggregate_function_state.reset(aggregate_function->sizeOfData(), aggregate_function->alignOfData()); aggregate_function->create(workspace.aggregate_function_state.data()); @@ -1290,7 +1290,7 @@ void WindowTransformAction::writeOutCurrentRow() 
IColumn * result_column = block.output_columns[ws.idx].get(); const auto * agg_func = ws.aggregate_function.get(); auto * buf = ws.aggregate_function_state.data(); - agg_func->insertResultInto(buf, *result_column, nullptr); + agg_func->insertResultInto(buf, *result_column, arena.get()); } } diff --git a/dbms/src/DataStreams/WindowTransformAction.h b/dbms/src/DataStreams/WindowTransformAction.h index bc1db118817..7fb00aec134 100644 --- a/dbms/src/DataStreams/WindowTransformAction.h +++ b/dbms/src/DataStreams/WindowTransformAction.h @@ -213,12 +213,12 @@ struct WindowTransformAction if constexpr (is_add) { - agg_func->addBatchSinglePlace(start_row, end_row - start_row, buf, columns, nullptr); + agg_func->addBatchSinglePlace(start_row, end_row - start_row, buf, columns, arena.get()); } else { for (auto row = start_row; row < end_row; ++row) - agg_func->decrease(buf, columns, row, nullptr); + agg_func->decrease(buf, columns, row, arena.get()); } } } @@ -318,5 +318,7 @@ struct WindowTransformAction RowNumber prev_frame_end; bool has_rank_or_dense_rank = false; + + std::unique_ptr arena; }; } // namespace DB diff --git a/dbms/src/Debug/MockKVStore/MockProxyRegion.cpp b/dbms/src/Debug/MockKVStore/MockProxyRegion.cpp index 5b390669ef3..2ac3b9e341c 100644 --- a/dbms/src/Debug/MockKVStore/MockProxyRegion.cpp +++ b/dbms/src/Debug/MockKVStore/MockProxyRegion.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -22,7 +23,6 @@ #include #include #include -#include #include #include #include @@ -172,4 +172,4 @@ void MockProxyRegion::addPeer(uint64_t store_id, uint64_t peer_id, metapb::PeerR peer.set_id(peer_id); peer.set_role(role); } -} // namespace DB \ No newline at end of file +} // namespace DB diff --git a/dbms/src/Debug/MockKVStore/MockRaftStoreProxy.cpp b/dbms/src/Debug/MockKVStore/MockRaftStoreProxy.cpp index 12c69e1a977..fd058d23186 100644 --- a/dbms/src/Debug/MockKVStore/MockRaftStoreProxy.cpp +++ 
b/dbms/src/Debug/MockKVStore/MockRaftStoreProxy.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -28,10 +29,10 @@ #include #include #include +#include #include #include #include -#include #include #include #include @@ -39,11 +40,6 @@ namespace DB { -namespace RegionBench -{ -extern void setupPutRequest(raft_cmdpb::Request *, const std::string &, const TiKVKey &, const TiKVValue &); -extern void setupDelRequest(raft_cmdpb::Request *, const std::string &, const TiKVKey &); -} // namespace RegionBench TiFlashRaftProxyHelper MockRaftStoreProxy::setRaftStoreProxyFFIHelper(RaftStoreProxyPtr proxy_ptr) { @@ -180,7 +176,11 @@ void MockRaftStoreProxy::debugAddRegions( auto lock = kvs.genRegionMgrWriteLock(task_lock); // Region mgr lock for (int i = 0; i < n; ++i) { - auto region = tests::makeRegion(region_ids[i], ranges[i].first, ranges[i].second, kvs.getProxyHelper()); + auto region = RegionBench::makeRegionForRange( + region_ids[i], + ranges[i].first, + ranges[i].second, + kvs.getProxyHelper()); lock.regions.emplace(region_ids[i], region); lock.index.add(region); tmt.getRegionTable().addRegion(*region); diff --git a/dbms/src/Debug/MockKVStore/MockRaftStoreProxy.h b/dbms/src/Debug/MockKVStore/MockRaftStoreProxy.h index bcabd825a75..28ab70ee3ce 100644 --- a/dbms/src/Debug/MockKVStore/MockRaftStoreProxy.h +++ b/dbms/src/Debug/MockKVStore/MockRaftStoreProxy.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include diff --git a/dbms/src/Debug/MockKVStore/MockReadIndex.cpp b/dbms/src/Debug/MockKVStore/MockReadIndex.cpp index 76858b658dd..bf5951392dd 100644 --- a/dbms/src/Debug/MockKVStore/MockReadIndex.cpp +++ b/dbms/src/Debug/MockKVStore/MockReadIndex.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -30,7 +31,6 @@ #include #include #include -#include #include #include #include diff --git a/dbms/src/Debug/MockKVStore/MockSSTReader.h 
b/dbms/src/Debug/MockKVStore/MockSSTReader.h index 30fc6dad50c..25ffee2f674 100644 --- a/dbms/src/Debug/MockKVStore/MockSSTReader.h +++ b/dbms/src/Debug/MockKVStore/MockSSTReader.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -27,8 +28,6 @@ class KVStore; using KVStorePtr = std::shared_ptr; class RegionTable; -class Region; -using RegionPtr = std::shared_ptr; /// Some helper structure / functions for IngestSST diff --git a/dbms/src/Debug/MockKVStore/MockTiKV.cpp b/dbms/src/Debug/MockKVStore/MockTiKV.cpp new file mode 100644 index 00000000000..d0d82165395 --- /dev/null +++ b/dbms/src/Debug/MockKVStore/MockTiKV.cpp @@ -0,0 +1,69 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include + +namespace DB +{ +RegionPtr MockTiKV::createRegion(TableID table_id, RegionID region_id, const HandleID & start, const HandleID & end) +{ + // peer_id is a fake number here + auto meta_region = RegionBench::createMetaRegion(region_id, table_id, start, end); + metapb::Peer peer; + RegionMeta region_meta(std::move(peer), std::move(meta_region), initialApplyState()); + // set the index according to mock tikv + UInt64 index = getNextRaftIndex(region_id); + region_meta.setApplied(index, RAFT_INIT_LOG_TERM); + return RegionBench::makeRegion(std::move(region_meta)); +} + +RegionPtr MockTiKV::createRegionCommonHandle( + const TiDB::TableInfo & table_info, + RegionID region_id, + std::vector & start_keys, + std::vector & end_keys) +{ + metapb::Region region = RegionBench::createMetaRegionCommonHandle( + region_id, + RecordKVFormat::genKey(table_info, start_keys), + RecordKVFormat::genKey(table_info, end_keys)); + + metapb::Peer peer; + RegionMeta region_meta(std::move(peer), std::move(region), initialApplyState()); + // set the index according to mock tikv + auto index = getNextRaftIndex(region_id); + region_meta.setApplied(index, RAFT_INIT_LOG_TERM); + return RegionBench::makeRegion(std::move(region_meta)); +} + +Regions MockTiKV::createRegions( + TableID table_id, + size_t region_num, + size_t key_num_each_region, + HandleID handle_begin, + RegionID new_region_id_begin) +{ + Regions regions; + for (RegionID region_id = new_region_id_begin; region_id < static_cast(new_region_id_begin + region_num); + ++region_id, handle_begin += key_num_each_region) + { + auto ptr = createRegion(table_id, region_id, handle_begin, handle_begin + key_num_each_region); + regions.push_back(ptr); + } + return regions; +} +} // namespace DB diff --git a/dbms/src/Debug/MockKVStore/MockTiKV.h b/dbms/src/Debug/MockKVStore/MockTiKV.h index 6e17370e655..647cb3a6cf0 100644 --- a/dbms/src/Debug/MockKVStore/MockTiKV.h +++ 
b/dbms/src/Debug/MockKVStore/MockTiKV.h @@ -1,4 +1,4 @@ -// Copyright 2023 PingCAP, Inc. +// Copyright 2025 PingCAP, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,15 +15,43 @@ #pragma once #include +#include +#include namespace DB { + class MockTiKV : public ext::Singleton { friend class ext::Singleton; public: - UInt64 getRaftIndex(RegionID region_id) + // Generate a RegionPtr with given params. + // The raft-index is set according to the mock raft-index on `MockTiKV` instance. + RegionPtr createRegion( // + TableID table_id, + RegionID region_id, + const HandleID & start, + const HandleID & end); + + // Generate a RegionPtr with given params of common handle. + // The raft-index is set according to the mock raft-index on `MockTiKV` instance. + RegionPtr createRegionCommonHandle( + const TiDB::TableInfo & table_info, + RegionID region_id, + std::vector & start_keys, + std::vector & end_keys); + + // Generate multiple RegionPtrs with given params. + // Each Region's raft-index is set according to the mock raft-index on `MockTiKV` instance + Regions createRegions( + TableID table_id, + size_t region_num, + size_t key_num_each_region, + HandleID handle_begin, + RegionID new_region_id_begin); + + UInt64 getNextRaftIndex(RegionID region_id) { std::lock_guard lock(mutex); auto it = raft_index.find(region_id); diff --git a/dbms/src/Debug/MockKVStore/MockUtils.cpp b/dbms/src/Debug/MockKVStore/MockUtils.cpp new file mode 100644 index 00000000000..85f884cc557 --- /dev/null +++ b/dbms/src/Debug/MockKVStore/MockUtils.cpp @@ -0,0 +1,897 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + + +namespace DB::ErrorCodes +{ +extern const int LOGICAL_ERROR; +extern const int UNKNOWN_TABLE; +extern const int UNKNOWN_DATABASE; +} // namespace DB::ErrorCodes + +namespace DB::RegionBench +{ +metapb::Peer createPeer(UInt64 id, bool) +{ + metapb::Peer peer; + peer.set_id(id); + return peer; +} + +metapb::Region createMetaRegion( + RegionID region_id, + TableID table_id, + HandleID start, + HandleID end, + std::optional maybe_epoch, + std::optional> maybe_peers) +{ + TiKVKey start_key = RecordKVFormat::genKey(table_id, start); + TiKVKey end_key = RecordKVFormat::genKey(table_id, end); + + return createMetaRegionCommonHandle(region_id, start_key, end_key, maybe_epoch, maybe_peers); +} + +metapb::Region createMetaRegionCommonHandle( + RegionID region_id, + const std::string & start_key, + const std::string & end_key, + std::optional maybe_epoch, + std::optional> maybe_peers) +{ + metapb::Region meta; + meta.set_id(region_id); + + meta.set_start_key(start_key); + meta.set_end_key(end_key); + + if (maybe_epoch) + { + *meta.mutable_region_epoch() = maybe_epoch.value(); + } + else + { + meta.mutable_region_epoch()->set_version(5); + meta.mutable_region_epoch()->set_conf_ver(6); + } + + if (maybe_peers) + { + const auto & peers = maybe_peers.value(); + for (const auto & peer : peers) + { + *(meta.mutable_peers()->Add()) = peer; + } + } + else + { + 
*(meta.mutable_peers()->Add()) = createPeer(1, true); + *(meta.mutable_peers()->Add()) = createPeer(2, false); + } + + return meta; +} + + +RegionPtr makeRegion(RegionMeta && meta) +{ + return std::make_shared(std::move(meta), nullptr); +} + +RegionPtr makeRegionForRange( + UInt64 id, + std::string start_key, + std::string end_key, + const TiFlashRaftProxyHelper * proxy_helper) +{ + return std::make_shared( + RegionMeta( + createPeer(2, true), + createMetaRegionCommonHandle(id, std::move(start_key), std::move(end_key)), + initialApplyState()), + proxy_helper); +} + +RegionPtr makeRegionForTable( + UInt64 region_id, + TableID table_id, + HandleID start, + HandleID end, + const TiFlashRaftProxyHelper * proxy_helper) +{ + return makeRegionForRange( + region_id, + RecordKVFormat::genKey(table_id, start).toString(), + RecordKVFormat::genKey(table_id, end).toString(), + proxy_helper); +} + +// Generates a lock value which fills all fields, only for test use. +TiKVValue encodeFullLockCfValue( + UInt8 lock_type, + const String & primary, + Timestamp ts, + UInt64 ttl, + const String * short_value, + Timestamp min_commit_ts, + Timestamp for_update_ts, + uint64_t txn_size, + const std::vector & async_commit, + const std::vector & rollback, + UInt64 generation) +{ + auto lock_value = RecordKVFormat::encodeLockCfValue(lock_type, primary, ts, ttl, short_value, min_commit_ts); + WriteBufferFromOwnString res; + res.write(lock_value.getStr().data(), lock_value.getStr().size()); + { + res.write(RecordKVFormat::MIN_COMMIT_TS_PREFIX); + RecordKVFormat::encodeUInt64(min_commit_ts, res); + } + { + res.write(RecordKVFormat::FOR_UPDATE_TS_PREFIX); + RecordKVFormat::encodeUInt64(for_update_ts, res); + } + { + res.write(RecordKVFormat::TXN_SIZE_PREFIX); + RecordKVFormat::encodeUInt64(txn_size, res); + } + { + res.write(RecordKVFormat::ROLLBACK_TS_PREFIX); + TiKV::writeVarUInt(rollback.size(), res); + for (auto ts : rollback) + { + RecordKVFormat::encodeUInt64(ts, res); + } + } + { + 
res.write(RecordKVFormat::ASYNC_COMMIT_PREFIX); + TiKV::writeVarUInt(async_commit.size(), res); + for (const auto & s : async_commit) + { + writeVarInt(s.size(), res); + res.write(s.data(), s.size()); + } + } + { + res.write(RecordKVFormat::LAST_CHANGE_PREFIX); + RecordKVFormat::encodeUInt64(12345678, res); + TiKV::writeVarUInt(87654321, res); + } + { + res.write(RecordKVFormat::TXN_SOURCE_PREFIX_FOR_LOCK); + TiKV::writeVarUInt(876543, res); + } + { + res.write(RecordKVFormat::PESSIMISTIC_LOCK_WITH_CONFLICT_PREFIX); + } + if (generation > 0) + { + res.write(RecordKVFormat::GENERATION_PREFIX); + RecordKVFormat::encodeUInt64(generation, res); + } + return TiKVValue(res.releaseStr()); +} + +using TiDB::ColumnInfo; + +void setupPutRequest(raft_cmdpb::Request * req, const std::string & cf, const TiKVKey & key, const TiKVValue & value) +{ + req->set_cmd_type(raft_cmdpb::CmdType::Put); + raft_cmdpb::PutRequest * put = req->mutable_put(); + put->set_cf(cf.c_str()); + put->set_key(key.getStr()); + put->set_value(value.getStr()); +} + +void setupDelRequest(raft_cmdpb::Request * req, const std::string & cf, const TiKVKey & key) +{ + req->set_cmd_type(raft_cmdpb::CmdType::Delete); + raft_cmdpb::DeleteRequest * del = req->mutable_delete_(); + del->set_cf(cf.c_str()); + del->set_key(key.getStr()); +} + +void addRequestsToRaftCmd( + raft_cmdpb::RaftCmdRequest & request, + const TiKVKey & key, + const TiKVValue & value, + UInt64 prewrite_ts, + UInt64 commit_ts, + bool del, + const String & pk) +{ + TiKVKey commit_key = RecordKVFormat::appendTs(key, commit_ts); + const TiKVKey & lock_key = key; + + if (del) + { + TiKVValue lock_value = RecordKVFormat::encodeLockCfValue(Region::DelFlag, pk, prewrite_ts, 0); + TiKVValue commit_value = RecordKVFormat::encodeWriteCfValue(Region::DelFlag, prewrite_ts); + + setupPutRequest(request.add_requests(), ColumnFamilyName::Lock, lock_key, lock_value); + setupPutRequest(request.add_requests(), ColumnFamilyName::Write, commit_key, commit_value); + 
setupDelRequest(request.add_requests(), ColumnFamilyName::Lock, lock_key); + return; + } + + if (value.dataSize() <= RecordKVFormat::SHORT_VALUE_MAX_LEN) + { + TiKVValue lock_value = RecordKVFormat::encodeLockCfValue(Region::PutFlag, pk, prewrite_ts, 0); + + TiKVValue commit_value = RecordKVFormat::encodeWriteCfValue(Region::PutFlag, prewrite_ts, value.toString()); + + setupPutRequest(request.add_requests(), ColumnFamilyName::Lock, lock_key, lock_value); + setupPutRequest(request.add_requests(), ColumnFamilyName::Write, commit_key, commit_value); + setupDelRequest(request.add_requests(), ColumnFamilyName::Lock, lock_key); + } + else + { + TiKVValue lock_value = RecordKVFormat::encodeLockCfValue(Region::PutFlag, pk, prewrite_ts, 0); + + TiKVKey prewrite_key = RecordKVFormat::appendTs(key, prewrite_ts); + const TiKVValue & prewrite_value = value; + + TiKVValue commit_value = RecordKVFormat::encodeWriteCfValue(Region::PutFlag, prewrite_ts); + + setupPutRequest(request.add_requests(), ColumnFamilyName::Lock, lock_key, lock_value); + setupPutRequest(request.add_requests(), ColumnFamilyName::Write, commit_key, commit_value); + setupPutRequest(request.add_requests(), ColumnFamilyName::Default, prewrite_key, prewrite_value); + setupDelRequest(request.add_requests(), ColumnFamilyName::Lock, lock_key); + } +} + +template +T convertNumber(const Field & field) +{ + switch (field.getType()) + { + case Field::Types::Int64: + return static_cast(field.get()); + case Field::Types::UInt64: + return static_cast(field.get()); + case Field::Types::Float64: + return static_cast(field.get()); + case Field::Types::Decimal32: + return static_cast(field.get>()); + case Field::Types::Decimal64: + return static_cast(field.get>()); + case Field::Types::Decimal128: + return static_cast(field.get>()); + case Field::Types::Decimal256: + return static_cast(field.get>()); + default: + throw Exception( + String("Unable to convert field type ") + field.getTypeName() + " to number", + 
ErrorCodes::LOGICAL_ERROR); + } +} + +Field convertDecimal(const ColumnInfo & column_info, const Field & field) +{ + switch (field.getType()) + { + case Field::Types::Int64: + return column_info.getDecimalValue(std::to_string(field.get())); + case Field::Types::UInt64: + return column_info.getDecimalValue(std::to_string(field.get())); + case Field::Types::Float64: + return column_info.getDecimalValue(std::to_string(field.get())); + case Field::Types::Decimal32: + return column_info.getDecimalValue(field.get().toString(column_info.decimal)); + case Field::Types::Decimal64: + return column_info.getDecimalValue(field.get().toString(column_info.decimal)); + case Field::Types::Decimal128: + return column_info.getDecimalValue(field.get().toString(column_info.decimal)); + case Field::Types::Decimal256: + return column_info.getDecimalValue(field.get().toString(column_info.decimal)); + default: + throw Exception( + String("Unable to convert field type ") + field.getTypeName() + " to number", + ErrorCodes::LOGICAL_ERROR); + } +} + +Field convertEnum(const ColumnInfo & column_info, const Field & field) +{ + switch (field.getType()) + { + case Field::Types::Int64: + case Field::Types::UInt64: + return convertNumber(field); + case Field::Types::String: + return static_cast(column_info.getEnumIndex(field.get())); + default: + throw Exception( + String("Unable to convert field type ") + field.getTypeName() + " to Enum", + ErrorCodes::LOGICAL_ERROR); + } +} + +Field convertField(const ColumnInfo & column_info, const Field & field) +{ + if (field.isNull()) + return field; + + switch (column_info.tp) + { + case TiDB::TypeTiny: + case TiDB::TypeShort: + case TiDB::TypeLong: + case TiDB::TypeLongLong: + case TiDB::TypeInt24: + if (column_info.hasUnsignedFlag()) + return convertNumber(field); + else + return convertNumber(field); + case TiDB::TypeFloat: + case TiDB::TypeDouble: + return convertNumber(field); + case TiDB::TypeDate: + case TiDB::TypeDatetime: + case TiDB::TypeTimestamp: 
+ return parseMyDateTime(field.safeGet()); + case TiDB::TypeVarchar: + case TiDB::TypeTinyBlob: + case TiDB::TypeMediumBlob: + case TiDB::TypeLongBlob: + case TiDB::TypeBlob: + case TiDB::TypeVarString: + case TiDB::TypeString: + return field; + case TiDB::TypeEnum: + return convertEnum(column_info, field); + case TiDB::TypeNull: + return Field(); + case TiDB::TypeDecimal: + case TiDB::TypeNewDecimal: + return convertDecimal(column_info, field); + case TiDB::TypeTime: + case TiDB::TypeYear: + return convertNumber(field); + case TiDB::TypeSet: + case TiDB::TypeBit: + return convertNumber(field); + default: + return Field(); + } +} + +void encodeRow(const TiDB::TableInfo & table_info, const std::vector & fields, WriteBuffer & ss) +{ + if (table_info.columns.size() < fields.size() + table_info.pk_is_handle) + throw Exception( + "Encoding row has less columns than encode values [num_columns=" + DB::toString(table_info.columns.size()) + + "] [num_fields=" + DB::toString(fields.size()) + "] . ", + ErrorCodes::LOGICAL_ERROR); + + std::vector flatten_fields; + std::unordered_set pk_column_names; + if (table_info.is_common_handle) + { + for (const auto & idx_col : table_info.getPrimaryIndexInfo().idx_cols) + { + // todo support prefix index + pk_column_names.insert(idx_col.name); + } + } + for (size_t i = 0; i < fields.size(); i++) + { + const auto & column_info = table_info.columns[i]; + /// skip the columns encoded in the key + if (pk_column_names.find(column_info.name) != pk_column_names.end()) + continue; + Field field = convertField(column_info, fields[i]); + TiDB::DatumBumpy datum = TiDB::DatumBumpy(field, column_info.tp); + flatten_fields.emplace_back(datum.field()); + } + + static bool row_format_flip = false; + // Ping-pong encoding using row format V1/V2. + (row_format_flip = !row_format_flip) ? 
encodeRowV1(table_info, flatten_fields, ss) + : encodeRowV2(table_info, flatten_fields, ss); +} + +void insert( // + const TiDB::TableInfo & table_info, + RegionID region_id, + HandleID handle_id, // + ASTs::const_iterator values_begin, + ASTs::const_iterator values_end, // + Context & context, + const std::optional> & tso_del) +{ + // Parse the fields in the inserted row + std::vector fields; + { + for (auto it = values_begin; it != values_end; ++it) + { + auto field = typeid_cast((*it).get())->value; + fields.emplace_back(field); + } + if (fields.size() + table_info.pk_is_handle != table_info.columns.size()) + throw Exception("Number of insert values and columns do not match.", ErrorCodes::LOGICAL_ERROR); + } + TMTContext & tmt = context.getTMTContext(); + pingcap::pd::ClientPtr pd_client = tmt.getPDClient(); + RegionPtr region = tmt.getKVStore()->getRegion(region_id); + + // Using the region meta's table ID rather than table_info's, as this could be a partition table so that the table ID should be partition ID. 
+ const auto range = region->getRange(); + TableID table_id = RecordKVFormat::getTableId(*range->rawKeys().first); + + TiKVKey key; + if (table_info.is_common_handle) + { + std::vector keys; + const auto & pk_index = table_info.getPrimaryIndexInfo(); + for (const auto & idx_col : pk_index.idx_cols) + { + const auto & column_info = table_info.columns[idx_col.offset]; + auto start_field = RegionBench::convertField(column_info, fields[idx_col.offset]); + TiDB::DatumBumpy start_datum = TiDB::DatumBumpy(start_field, column_info.tp); + keys.emplace_back(start_datum.field()); + } + key = RecordKVFormat::genKey(table_info, keys); + } + else + key = RecordKVFormat::genKey(table_id, handle_id); + WriteBufferFromOwnString ss; + encodeRow(table_info, fields, ss); + TiKVValue value(ss.releaseStr()); + + UInt64 prewrite_ts = pd_client->getTS(); + UInt64 commit_ts = pd_client->getTS(); + bool is_del = false; + + if (tso_del.has_value()) + { + auto [tso, del] = *tso_del; + prewrite_ts = tso; + commit_ts = tso; + is_del = del; + } + + raft_cmdpb::RaftCmdRequest request; + addRequestsToRaftCmd(request, key, value, prewrite_ts, commit_ts, is_del); + RegionBench::applyWriteRaftCmd( + *tmt.getKVStore(), + std::move(request), + region_id, + MockTiKV::instance().getNextRaftIndex(region_id), + MockTiKV::instance().getRaftTerm(region_id), + tmt); +} + +void remove(const TiDB::TableInfo & table_info, RegionID region_id, HandleID handle_id, Context & context) +{ + static const TiKVValue value; + + TiKVKey key = RecordKVFormat::genKey(table_info.id, handle_id); + + TMTContext & tmt = context.getTMTContext(); + pingcap::pd::ClientPtr pd_client = tmt.getPDClient(); + RegionPtr region = tmt.getKVStore()->getRegion(region_id); + + UInt64 prewrite_ts = pd_client->getTS(); + UInt64 commit_ts = pd_client->getTS(); + + raft_cmdpb::RaftCmdRequest request; + addRequestsToRaftCmd(request, key, value, prewrite_ts, commit_ts, true); + RegionBench::applyWriteRaftCmd( + *tmt.getKVStore(), + 
std::move(request), + region_id, + MockTiKV::instance().getNextRaftIndex(region_id), + MockTiKV::instance().getRaftTerm(region_id), + tmt); +} + +struct BatchCtrl +{ + String default_str; + Int64 concurrent_id; + Int64 flush_num; + Int64 batch_num; + UInt64 min_strlen; + UInt64 max_strlen; + Context * context; + RegionPtr region; + HandleID handle_begin; + bool del; + + BatchCtrl( + Int64 concurrent_id_, + Int64 flush_num_, + Int64 batch_num_, + UInt64 min_strlen_, + UInt64 max_strlen_, + Context * context_, + RegionPtr region_, + HandleID handle_begin_, + bool del_) + : concurrent_id(concurrent_id_) + , flush_num(flush_num_) + , batch_num(batch_num_) + , min_strlen(min_strlen_) + , max_strlen(max_strlen_) + , context(context_) + , region(region_) + , handle_begin(handle_begin_) + , del(del_) + { + assert(max_strlen >= min_strlen); + assert(min_strlen >= 1); + auto str_len = static_cast(random() % (max_strlen - min_strlen + 1) + min_strlen); + default_str = String(str_len, '_'); + } + + void encodeDatum(WriteBuffer & ss, TiDB::CodecFlag flag, Int64 magic_num) + { + Int8 target = (magic_num % 70) + '0'; + EncodeUInt(static_cast(flag), ss); + switch (flag) + { + case TiDB::CodecFlagJson: + throw Exception( + "Not implemented yet: BatchCtrl::encodeDatum, TiDB::CodecFlagJson", + ErrorCodes::LOGICAL_ERROR); + case TiDB::CodecFlagVectorFloat32: + throw Exception( + "Not implemented yet: BatchCtrl::encodeDatum, TiDB::CodecFlagVectorFloat32", + ErrorCodes::LOGICAL_ERROR); + case TiDB::CodecFlagMax: + throw Exception( + "Not implemented yet: BatchCtrl::encodeDatum, TiDB::CodecFlagMax", + ErrorCodes::LOGICAL_ERROR); + case TiDB::CodecFlagDuration: + throw Exception( + "Not implemented yet: BatchCtrl::encodeDatum, TiDB::CodecFlagDuration", + ErrorCodes::LOGICAL_ERROR); + case TiDB::CodecFlagNil: + return; + case TiDB::CodecFlagBytes: + memset(default_str.data(), target, default_str.size()); + return EncodeBytes(default_str, ss); + //case TiDB::CodecFlagDecimal: + // return 
EncodeDecimal(Decimal(magic_num), ss); + case TiDB::CodecFlagCompactBytes: + memset(default_str.data(), target, default_str.size()); + return EncodeCompactBytes(default_str, ss); + case TiDB::CodecFlagFloat: + return EncodeFloat64(static_cast(magic_num) / 1111.1, ss); + case TiDB::CodecFlagUInt: + return EncodeUInt(static_cast(magic_num), ss); + case TiDB::CodecFlagInt: + return EncodeInt64((magic_num), ss); + case TiDB::CodecFlagVarInt: + return EncodeVarInt((magic_num), ss); + case TiDB::CodecFlagVarUInt: + return EncodeVarUInt(static_cast(magic_num), ss); + default: + throw Exception("Not implemented codec flag: " + std::to_string(flag), ErrorCodes::LOGICAL_ERROR); + } + } + + TiKVValue encodeRow(const TiDB::TableInfo & table_info, Int64 magic_num) + { + WriteBufferFromOwnString ss; + for (const auto & column : table_info.columns) + { + encodeDatum(ss, TiDB::CodecFlagInt, column.id); + // TODO: May need to use BumpyDatum to flatten before encoding. + encodeDatum(ss, column.getCodecFlag(), magic_num); + } + return TiKVValue(ss.releaseStr()); + } +}; + +void batchInsert( + const TiDB::TableInfo & table_info, + std::unique_ptr batch_ctrl, + std::function fn_gen_magic_num) +{ + RegionPtr & region = batch_ctrl->region; + + TMTContext & tmt = batch_ctrl->context->getTMTContext(); + pingcap::pd::ClientPtr pd_client = tmt.getPDClient(); + + Int64 index = batch_ctrl->handle_begin; + + for (Int64 flush_cnt = 0; flush_cnt < batch_ctrl->flush_num; ++flush_cnt) + { + UInt64 prewrite_ts = pd_client->getTS(); + UInt64 commit_ts = pd_client->getTS(); + + raft_cmdpb::RaftCmdRequest request; + + for (Int64 cnt = 0; cnt < batch_ctrl->batch_num; ++index, ++cnt) + { + TiKVKey key = RecordKVFormat::genKey(table_info.id, index); + TiKVValue value = batch_ctrl->encodeRow(table_info, fn_gen_magic_num(index)); + addRequestsToRaftCmd(request, key, value, prewrite_ts, commit_ts, batch_ctrl->del); + } + + RegionBench::applyWriteRaftCmd( + *tmt.getKVStore(), + std::move(request), + 
region->id(), + MockTiKV::instance().getNextRaftIndex(region->id()), + MockTiKV::instance().getRaftTerm(region->id()), + tmt); + } +} + +void concurrentBatchInsert( + const TiDB::TableInfo & table_info, + Int64 concurrent_num, + Int64 flush_num, + Int64 batch_num, + UInt64 min_strlen, + UInt64 max_strlen, + Context & context) +{ + TMTContext & tmt = context.getTMTContext(); + + RegionID curr_max_region_id(InvalidRegionID); + HandleID curr_max_handle_id = 0; + tmt.getKVStore()->traverseRegions([&](const RegionID region_id, const RegionPtr & region) { + curr_max_region_id + = (curr_max_region_id == InvalidRegionID) ? region_id : std::max(curr_max_region_id, region_id); + const auto range = region->getRange(); + curr_max_handle_id = std::max(RecordKVFormat::getHandle(*range->rawKeys().second), curr_max_handle_id); + }); + + Int64 key_num_each_region = flush_num * batch_num; + HandleID handle_begin = curr_max_handle_id; + + + auto debug_kvstore = RegionBench::DebugKVStore(*tmt.getKVStore()); + Regions regions = MockTiKV::instance().createRegions( // + table_info.id, + concurrent_num, + key_num_each_region, + handle_begin, + curr_max_region_id + 1); + for (const RegionPtr & region : regions) + debug_kvstore.onSnapshot(RegionPtrWithSnapshotFiles{region, {}}, nullptr, 0, tmt); + + std::list threads; + for (Int64 i = 0; i < concurrent_num; i++, handle_begin += key_num_each_region) + { + auto batch_ptr = std::make_unique< + BatchCtrl>(i, flush_num, batch_num, min_strlen, max_strlen, &context, regions[i], handle_begin, false); + threads.push_back( + std::thread(&batchInsert, table_info, std::move(batch_ptr), [](Int64 index) -> Int64 { return index; })); + } + for (auto & thread : threads) + { + thread.join(); + } +} + +Int64 concurrentRangeOperate( + const TiDB::TableInfo & table_info, + HandleID start_handle, + HandleID end_handle, + Context & context, + Int64 magic_num, + bool del) +{ + Regions regions; + + { + TMTContext & tmt = context.getTMTContext(); + for (auto && [_, 
r] : tmt.getRegionTable().getRegionsByTable(NullspaceID, table_info.id)) + { + std::ignore = _; + if (r == nullptr) + continue; + regions.push_back(r); + } + } + + std::shuffle(regions.begin(), regions.end(), std::default_random_engine()); + + std::list threads; + Int64 tol = 0; + for (const auto & region : regions) + { + const auto range = region->getRange(); + const auto & [ss, ee] = getHandleRangeByTable(range->rawKeys(), table_info.id); + TiKVRange::Handle handle_begin = std::max(ss, start_handle); + TiKVRange::Handle handle_end = std::min(ee, end_handle); + if (handle_end <= handle_begin) + continue; + Int64 batch_num = handle_end - handle_begin; + tol += batch_num; + auto batch_ptr + = std::make_unique(-1, 1, batch_num, 1, 1, &context, region, handle_begin.handle_id, del); + threads.push_back(std::thread(&batchInsert, table_info, std::move(batch_ptr), [=](Int64 index) -> Int64 { + std::ignore = index; + return magic_num; + })); + } + for (auto & thread : threads) + { + thread.join(); + } + return tol; +} + +TableID getTableID( + Context & context, + const std::string & database_name, + const std::string & table_name, + const std::string & partition_id) +{ + try + { + using TablePtr = MockTiDB::TablePtr; + TablePtr table = MockTiDB::instance().getTableByName(database_name, table_name); + + if (table->isPartitionTable()) + return std::strtol(partition_id.c_str(), nullptr, 0); + + return table->id(); + } + catch (Exception & e) + { + if (e.code() != ErrorCodes::UNKNOWN_TABLE) + throw; + } + + auto mapped_table_name = mappedTable(context, database_name, table_name).second; + auto mapped_database_name = mappedDatabase(context, database_name); + auto storage = context.getTable(mapped_database_name, mapped_table_name); + auto managed_storage = std::static_pointer_cast(storage); + auto table_info = managed_storage->getTableInfo(); + return table_info.id; +} + +const TiDB::TableInfo & getTableInfo(Context & context, const String & database_name, const String & 
table_name) +{ + try + { + using TablePtr = MockTiDB::TablePtr; + TablePtr table = MockTiDB::instance().getTableByName(database_name, table_name); + + return table->table_info; + } + catch (Exception & e) + { + if (e.code() != ErrorCodes::UNKNOWN_TABLE) + throw; + } + + auto mapped_table_name = mappedTable(context, database_name, table_name).second; + auto mapped_database_name = mappedDatabase(context, database_name); + auto storage = context.getTable(mapped_database_name, mapped_table_name); + auto managed_storage = std::static_pointer_cast(storage); + return managed_storage->getTableInfo(); +} + + +EngineStoreApplyRes applyWriteRaftCmd( + KVStore & kvstore, + raft_cmdpb::RaftCmdRequest && request, + UInt64 region_id, + UInt64 index, + UInt64 term, + TMTContext & tmt, + DM::WriteResult * write_result_ptr) +{ + std::vector keys; + std::vector vals; + std::vector cmd_types; + std::vector cmd_cf; + keys.reserve(request.requests_size()); + vals.reserve(request.requests_size()); + cmd_types.reserve(request.requests_size()); + cmd_cf.reserve(request.requests_size()); + + for (const auto & req : request.requests()) + { + auto type = req.cmd_type(); + + switch (type) + { + case raft_cmdpb::CmdType::Put: + keys.push_back({req.put().key().data(), req.put().key().size()}); + vals.push_back({req.put().value().data(), req.put().value().size()}); + cmd_types.push_back(WriteCmdType::Put); + cmd_cf.push_back(NameToCF(req.put().cf())); + break; + case raft_cmdpb::CmdType::Delete: + keys.push_back({req.delete_().key().data(), req.delete_().key().size()}); + vals.push_back({nullptr, 0}); + cmd_types.push_back(WriteCmdType::Del); + cmd_cf.push_back(NameToCF(req.delete_().cf())); + break; + default: + throw Exception( + fmt::format("Unsupport raft cmd {}", raft_cmdpb::CmdType_Name(type)), + ErrorCodes::LOGICAL_ERROR); + } + } + if (write_result_ptr) + { + return kvstore.handleWriteRaftCmdInner( + WriteCmdsView{ + .keys = keys.data(), + .vals = vals.data(), + .cmd_types = 
cmd_types.data(), + .cmd_cf = cmd_cf.data(), + .len = keys.size()}, + region_id, + index, + term, + tmt, + *write_result_ptr); + } + else + { + DM::WriteResult write_result; + return kvstore.handleWriteRaftCmdInner( + WriteCmdsView{ + .keys = keys.data(), + .vals = vals.data(), + .cmd_types = cmd_types.data(), + .cmd_cf = cmd_cf.data(), + .len = keys.size()}, + region_id, + index, + term, + tmt, + write_result); + } +} + +void handleApplySnapshot( + KVStore & kvstore, + metapb::Region && region, + uint64_t peer_id, + SSTViewVec snaps, + uint64_t index, + uint64_t term, + std::optional deadline_index, + TMTContext & tmt) +{ + auto new_region = kvstore.genRegionPtr(std::move(region), peer_id, index, term, tmt.getRegionTable()); + auto prehandle_result = kvstore.preHandleSnapshotToFiles(new_region, snaps, index, term, deadline_index, tmt); + kvstore.applyPreHandledSnapshot( + RegionPtrWithSnapshotFiles{new_region, std::move(prehandle_result.ingest_ids)}, + tmt); +} +} // namespace DB::RegionBench diff --git a/dbms/src/Debug/MockKVStore/MockUtils.h b/dbms/src/Debug/MockKVStore/MockUtils.h index 69cc246245d..d8781b4a2d6 100644 --- a/dbms/src/Debug/MockKVStore/MockUtils.h +++ b/dbms/src/Debug/MockKVStore/MockUtils.h @@ -14,90 +14,60 @@ #pragma once +#include +#include +#include #include #include #include -namespace DB::RegionBench +namespace DB { +class KVStore; +class TMTContext; +} // namespace DB - -inline metapb::Peer createPeer(UInt64 id, bool) +namespace DB::RegionBench { - metapb::Peer peer; - peer.set_id(id); - return peer; -} +metapb::Peer createPeer(UInt64 id, bool); -inline metapb::Region createRegionInfo( - UInt64 id, - const std::string start_key, - const std::string end_key, +metapb::Region createMetaRegion( // + RegionID region_id, + TableID table_id, + HandleID start, + HandleID end, std::optional maybe_epoch = std::nullopt, - std::optional> maybe_peers = std::nullopt) -{ - metapb::Region region_info; - region_info.set_id(id); - 
region_info.set_start_key(start_key); - region_info.set_end_key(end_key); - if (maybe_epoch) - { - *region_info.mutable_region_epoch() = (maybe_epoch.value()); - } - else - { - region_info.mutable_region_epoch()->set_version(5); - region_info.mutable_region_epoch()->set_version(6); - } - if (maybe_peers) - { - const auto & peers = maybe_peers.value(); - for (const auto & peer : peers) - { - *(region_info.mutable_peers()->Add()) = peer; - } - } - else - { - *(region_info.mutable_peers()->Add()) = createPeer(1, true); - *(region_info.mutable_peers()->Add()) = createPeer(2, false); - } - - return region_info; -} - -inline RegionMeta createRegionMeta( - UInt64 id, - DB::TableID table_id, - std::optional apply_state = std::nullopt) -{ - return RegionMeta( - /*peer=*/createPeer(31, true), - /*region=*/createRegionInfo(id, RecordKVFormat::genKey(table_id, 0), RecordKVFormat::genKey(table_id, 300)), - /*apply_state_=*/apply_state.value_or(initialApplyState())); -} + std::optional> maybe_peers = std::nullopt); -inline RegionPtr makeRegion( +metapb::Region createMetaRegionCommonHandle( // + RegionID region_id, + const std::string & start_key, + const std::string & end_key, + std::optional maybe_epoch = std::nullopt, + std::optional> maybe_peers = std::nullopt); + +/// Utils to create a RegionPtr for testing. +/// - If your tests don't care about the raft-index in the created RegionPtr, +/// use the following `makeRegionForTable`/`makeRegionForRange`/`makeRegion`. 
+/// - If raft-index matters, try `MockTiKV::instance().{createRegion,createRegionCommonHandle}` + +RegionPtr makeRegionForTable( + UInt64 region_id, + TableID table_id, + HandleID start, + HandleID end, + const TiFlashRaftProxyHelper * proxy_helper = nullptr); + +RegionPtr makeRegionForRange( UInt64 id, - const std::string start_key, - const std::string end_key, - const TiFlashRaftProxyHelper * proxy_helper = nullptr) -{ - return std::make_shared( - RegionMeta( - createPeer(2, true), - createRegionInfo(id, std::move(start_key), std::move(end_key)), - initialApplyState()), - proxy_helper); -} - -inline RegionPtr makeRegion(RegionMeta && meta) -{ - return std::make_shared(std::move(meta), nullptr); -} + std::string start_key, + std::string end_key, + const TiFlashRaftProxyHelper * proxy_helper = nullptr); + +RegionPtr makeRegion(RegionMeta && meta); // Generates a lock value which fills all fields, only for test use. -inline TiKVValue encodeFullLockCfValue( +TiKVValue encodeFullLockCfValue( UInt8 lock_type, const String & primary, Timestamp ts, @@ -108,58 +78,77 @@ inline TiKVValue encodeFullLockCfValue( uint64_t txn_size, const std::vector & async_commit, const std::vector & rollback, - UInt64 generation = 0) -{ - auto lock_value = RecordKVFormat::encodeLockCfValue(lock_type, primary, ts, ttl, short_value, min_commit_ts); - WriteBufferFromOwnString res; - res.write(lock_value.getStr().data(), lock_value.getStr().size()); - { - res.write(RecordKVFormat::MIN_COMMIT_TS_PREFIX); - RecordKVFormat::encodeUInt64(min_commit_ts, res); - } - { - res.write(RecordKVFormat::FOR_UPDATE_TS_PREFIX); - RecordKVFormat::encodeUInt64(for_update_ts, res); - } - { - res.write(RecordKVFormat::TXN_SIZE_PREFIX); - RecordKVFormat::encodeUInt64(txn_size, res); - } - { - res.write(RecordKVFormat::ROLLBACK_TS_PREFIX); - TiKV::writeVarUInt(rollback.size(), res); - for (auto ts : rollback) - { - RecordKVFormat::encodeUInt64(ts, res); - } - } - { - res.write(RecordKVFormat::ASYNC_COMMIT_PREFIX); 
- TiKV::writeVarUInt(async_commit.size(), res); - for (const auto & s : async_commit) - { - writeVarInt(s.size(), res); - res.write(s.data(), s.size()); - } - } - { - res.write(RecordKVFormat::LAST_CHANGE_PREFIX); - RecordKVFormat::encodeUInt64(12345678, res); - TiKV::writeVarUInt(87654321, res); - } - { - res.write(RecordKVFormat::TXN_SOURCE_PREFIX_FOR_LOCK); - TiKV::writeVarUInt(876543, res); - } - { - res.write(RecordKVFormat::PESSIMISTIC_LOCK_WITH_CONFLICT_PREFIX); - } - if (generation > 0) - { - res.write(RecordKVFormat::GENERATION_PREFIX); - RecordKVFormat::encodeUInt64(generation, res); - } - return TiKVValue(res.releaseStr()); -} + UInt64 generation = 0); + +void setupPutRequest(raft_cmdpb::Request * req, const std::string & cf, const TiKVKey & key, const TiKVValue & value); +void setupDelRequest(raft_cmdpb::Request *, const std::string &, const TiKVKey &); + +void encodeRow(const TiDB::TableInfo & table_info, const std::vector & fields, WriteBuffer & ss); + +void insert( + const TiDB::TableInfo & table_info, + RegionID region_id, + HandleID handle_id, + ASTs::const_iterator begin, + ASTs::const_iterator end, + Context & context, + const std::optional> & tso_del = {}); + +void addRequestsToRaftCmd( + raft_cmdpb::RaftCmdRequest & request, + const TiKVKey & key, + const TiKVValue & value, + UInt64 prewrite_ts, + UInt64 commit_ts, + bool del, + const String & pk = "pk"); + +void concurrentBatchInsert( + const TiDB::TableInfo & table_info, + Int64 concurrent_num, + Int64 flush_num, + Int64 batch_num, + UInt64 min_strlen, + UInt64 max_strlen, + Context & context); + +void remove(const TiDB::TableInfo & table_info, RegionID region_id, HandleID handle_id, Context & context); + +Int64 concurrentRangeOperate( + const TiDB::TableInfo & table_info, + HandleID start_handle, + HandleID end_handle, + Context & context, + Int64 magic_num, + bool del); + +Field convertField(const TiDB::ColumnInfo & column_info, const Field & field); + +TableID getTableID( + Context & 
context, + const std::string & database_name, + const std::string & table_name, + const std::string & partition_id); + +const TiDB::TableInfo & getTableInfo(Context & context, const String & database_name, const String & table_name); + +EngineStoreApplyRes applyWriteRaftCmd( + KVStore & kvstore, + raft_cmdpb::RaftCmdRequest && request, + UInt64 region_id, + UInt64 index, + UInt64 term, + TMTContext & tmt, + ::DB::DM::WriteResult * write_result_ptr = nullptr); + +void handleApplySnapshot( + KVStore & kvstore, + metapb::Region && region, + uint64_t peer_id, + SSTViewVec, + uint64_t index, + uint64_t term, + std::optional, + TMTContext & tmt); } // namespace DB::RegionBench diff --git a/dbms/src/Debug/ReadIndexStressTest.cpp b/dbms/src/Debug/ReadIndexStressTest.cpp index 3fd813c884e..02601c42bea 100644 --- a/dbms/src/Debug/ReadIndexStressTest.cpp +++ b/dbms/src/Debug/ReadIndexStressTest.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include #include diff --git a/dbms/src/Debug/dbgKVStore/dbgFuncMockRaftCommand.cpp b/dbms/src/Debug/dbgKVStore/dbgFuncMockRaftCommand.cpp index 6e3618d94dc..48aae3fd46f 100644 --- a/dbms/src/Debug/dbgKVStore/dbgFuncMockRaftCommand.cpp +++ b/dbms/src/Debug/dbgKVStore/dbgFuncMockRaftCommand.cpp @@ -138,7 +138,7 @@ void MockRaftCommand::dbgFuncRegionBatchSplit(Context & context, const ASTs & ar std::move(request), std::move(response), region_id, - MockTiKV::instance().getRaftIndex(region_id), + MockTiKV::instance().getNextRaftIndex(region_id), MockTiKV::instance().getRaftTerm(region_id), tmt); @@ -179,7 +179,7 @@ void MockRaftCommand::dbgFuncPrepareMerge(Context & context, const ASTs & args, std::move(request), std::move(response), region_id, - MockTiKV::instance().getRaftIndex(region_id), + MockTiKV::instance().getNextRaftIndex(region_id), MockTiKV::instance().getRaftTerm(region_id), tmt); @@ -216,7 +216,7 @@ void MockRaftCommand::dbgFuncCommitMerge(Context & context, const ASTs & args, D std::move(request), std::move(response), 
current_id, - MockTiKV::instance().getRaftIndex(current_id), + MockTiKV::instance().getNextRaftIndex(current_id), MockTiKV::instance().getRaftTerm(current_id), tmt); @@ -252,7 +252,7 @@ void MockRaftCommand::dbgFuncRollbackMerge(Context & context, const ASTs & args, std::move(request), std::move(response), region_id, - MockTiKV::instance().getRaftIndex(region_id), + MockTiKV::instance().getNextRaftIndex(region_id), MockTiKV::instance().getRaftTerm(region_id), tmt); diff --git a/dbms/src/Debug/dbgKVStore/dbgFuncMockRaftSnapshot.cpp b/dbms/src/Debug/dbgKVStore/dbgFuncMockRaftSnapshot.cpp index ebb844711c7..3fb6f02a668 100644 --- a/dbms/src/Debug/dbgKVStore/dbgFuncMockRaftSnapshot.cpp +++ b/dbms/src/Debug/dbgKVStore/dbgFuncMockRaftSnapshot.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -41,9 +42,9 @@ #include #include #include +#include #include #include -#include #include #include @@ -81,7 +82,7 @@ RegionPtr GenDbgRegionSnapshotWithData(Context & context, const ASTs & args) { auto start = static_cast(safeGet(typeid_cast(*args[3]).value)); auto end = static_cast(safeGet(typeid_cast(*args[4]).value)); - region = RegionBench::createRegion(table_id, region_id, start, end); + region = MockTiKV::instance().createRegion(table_id, region_id, start, end); } else { @@ -102,7 +103,7 @@ RegionPtr GenDbgRegionSnapshotWithData(Context & context, const ASTs & args) TiDB::DatumBumpy end_datum = TiDB::DatumBumpy(end_field, column_info.tp); end_keys.emplace_back(end_datum.field()); } - region = RegionBench::createRegion(table_info, region_id, start_keys, end_keys); + region = MockTiKV::instance().createRegionCommonHandle(table_info, region_id, start_keys, end_keys); } auto args_begin = args.begin() + 3 + handle_column_size * 2; @@ -160,7 +161,7 @@ RegionPtr GenDbgRegionSnapshotWithData(Context & context, const ASTs & args) std::move(commit_key), std::move(commit_value)); } - MockTiKV::instance().getRaftIndex(region_id); + 
MockTiKV::instance().getNextRaftIndex(region_id); } return region; } @@ -213,11 +214,8 @@ void MockRaftCommand::dbgFuncRegionSnapshot(Context & context, const ASTs & args TMTContext & tmt = context.getTMTContext(); - metapb::Region region_info; - TiKVKey start_key; TiKVKey end_key; - region_info.set_id(region_id); if (table_info.is_common_handle) { // Get start key and end key form multiple column if it is clustered_index. @@ -247,10 +245,13 @@ void MockRaftCommand::dbgFuncRegionSnapshot(Context & context, const ASTs & args start_key = RecordKVFormat::genKey(table_id, start); end_key = RecordKVFormat::genKey(table_id, end); } - region_info.set_start_key(start_key.toString()); - region_info.set_end_key(end_key.toString()); - *region_info.add_peers() = tests::createPeer(1, true); - *region_info.add_peers() = tests::createPeer(2, true); + metapb::Region region_info = RegionBench::createMetaRegionCommonHandle( + region_id, + start_key.toString(), + end_key.toString(), + std::nullopt, + std::vector{RegionBench::createPeer(1, true), RegionBench::createPeer(2, true)}); + auto peer_id = 1; auto start_decoded_key = RecordKVFormat::decodeTiKVKey(start_key); auto end_decoded_key = RecordKVFormat::decodeTiKVKey(end_key); @@ -261,7 +262,7 @@ void MockRaftCommand::dbgFuncRegionSnapshot(Context & context, const ASTs & args std::move(region_info), peer_id, SSTViewVec{nullptr, 0}, - MockTiKV::instance().getRaftIndex(region_id), + MockTiKV::instance().getNextRaftIndex(region_id), RAFT_INIT_LOG_TERM, std::nullopt, tmt); @@ -474,7 +475,7 @@ void MockRaftCommand::dbgFuncIngestSST(Context & context, const ASTs & args, DBG kvstore->handleIngestSST( region_id, SSTViewVec{sst_views.data(), sst_views.size()}, - MockTiKV::instance().getRaftIndex(region_id), + MockTiKV::instance().getNextRaftIndex(region_id), MockTiKV::instance().getRaftTerm(region_id), tmt); } @@ -489,7 +490,7 @@ void MockRaftCommand::dbgFuncIngestSST(Context & context, const ASTs & args, DBG kvstore->handleIngestSST( 
region_id, SSTViewVec{sst_views.data(), sst_views.size()}, - MockTiKV::instance().getRaftIndex(region_id), + MockTiKV::instance().getNextRaftIndex(region_id), MockTiKV::instance().getRaftTerm(region_id), tmt); } @@ -595,8 +596,8 @@ void MockRaftCommand::dbgFuncRegionSnapshotPreHandleDTFiles( // We may call this function mutiple time to mock some situation, try to reuse the region in `GLOBAL_REGION_MAP` // so that we can collect uncommitted data. - UInt64 index = MockTiKV::instance().getRaftIndex(region_id) + 1; - RegionPtr new_region = RegionBench::createRegion(table->id(), region_id, start_handle, end_handle + 10000, index); + RegionPtr new_region = MockTiKV::instance().createRegion(table->id(), region_id, start_handle, end_handle + 10000); + UInt64 index = new_region->appliedIndex() + 1; // Register some mock SST reading methods so that we can decode data in `MockSSTReader::MockSSTData` RegionMockTest mock_test(kvstore.get(), new_region); @@ -698,11 +699,11 @@ void MockRaftCommand::dbgFuncRegionSnapshotPreHandleDTFilesWithHandles( // We may call this function mutiple time to mock some situation, try to reuse the region in `GLOBAL_REGION_MAP` // so that we can collect uncommitted data. 
- UInt64 index = MockTiKV::instance().getRaftIndex(region_id) + 1; UInt64 region_start_handle = handles[0]; UInt64 region_end_handle = handles.back() + 10000; RegionPtr new_region - = RegionBench::createRegion(table->id(), region_id, region_start_handle, region_end_handle, index); + = MockTiKV::instance().createRegion(table->id(), region_id, region_start_handle, region_end_handle); + UInt64 index = new_region->appliedIndex() + 1; // Register some mock SST reading methods so that we can decode data in `MockSSTReader::MockSSTData` RegionMockTest mock_test(kvstore.get(), new_region); diff --git a/dbms/src/Debug/dbgKVStore/dbgFuncRegion.cpp b/dbms/src/Debug/dbgKVStore/dbgFuncRegion.cpp index 7b44e41c45c..2f3f0d0db56 100644 --- a/dbms/src/Debug/dbgKVStore/dbgFuncRegion.cpp +++ b/dbms/src/Debug/dbgKVStore/dbgFuncRegion.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -80,7 +81,7 @@ void dbgFuncPutRegion(Context & context, const ASTs & args, DBGInvoker::Printer } TMTContext & tmt = context.getTMTContext(); - RegionPtr region = RegionBench::createRegion(table_info, region_id, start_keys, end_keys); + RegionPtr region = MockTiKV::instance().createRegionCommonHandle(table_info, region_id, start_keys, end_keys); tmt.getKVStore() ->onSnapshot(RegionPtrWithSnapshotFiles{region, {}}, nullptr, 0, tmt); @@ -96,7 +97,7 @@ void dbgFuncPutRegion(Context & context, const ASTs & args, DBGInvoker::Printer auto end = static_cast(safeGet(typeid_cast(*args[2]).value)); TMTContext & tmt = context.getTMTContext(); - RegionPtr region = RegionBench::createRegion(table_id, region_id, start, end); + RegionPtr region = MockTiKV::instance().createRegion(table_id, region_id, start, end); tmt.getKVStore()->onSnapshot(region, nullptr, 0, tmt); output(fmt::format( diff --git a/dbms/src/Debug/dbgKVStore/dbgKVStore.h b/dbms/src/Debug/dbgKVStore/dbgKVStore.h index 1ab39dee834..fb39e0edaa7 100644 --- a/dbms/src/Debug/dbgKVStore/dbgKVStore.h +++ 
b/dbms/src/Debug/dbgKVStore/dbgKVStore.h @@ -17,6 +17,7 @@ #include #include #include +#include namespace DB::RegionBench { @@ -39,4 +40,4 @@ struct DebugKVStore private: KVStore & kvstore; }; -} // namespace DB::RegionBench \ No newline at end of file +} // namespace DB::RegionBench diff --git a/dbms/src/Debug/dbgKVStore/dbgRegion.h b/dbms/src/Debug/dbgKVStore/dbgRegion.h index ec3a1bb31c2..88c412ec50b 100644 --- a/dbms/src/Debug/dbgKVStore/dbgRegion.h +++ b/dbms/src/Debug/dbgKVStore/dbgRegion.h @@ -14,6 +14,8 @@ #pragma once +#include +#include #include namespace DB::RegionBench @@ -31,4 +33,4 @@ struct DebugRegion private: Region & region; }; -} // namespace DB::RegionBench \ No newline at end of file +} // namespace DB::RegionBench diff --git a/dbms/src/Debug/dbgNaturalDag.cpp b/dbms/src/Debug/dbgNaturalDag.cpp index 33d6aeb32ae..aba772f12ce 100644 --- a/dbms/src/Debug/dbgNaturalDag.cpp +++ b/dbms/src/Debug/dbgNaturalDag.cpp @@ -13,6 +13,7 @@ // limitations under the License. #include +#include #include #include #include @@ -27,6 +28,8 @@ #include #include #include +#include +#include #include #include @@ -223,19 +226,15 @@ void NaturalDag::buildTables(Context & context) schema_syncer->syncSchemas(context, NullspaceID); for (auto & region : table.regions) { - metapb::Region region_pb; metapb::Peer peer; - region_pb.set_id(region.id); - region_pb.set_start_key(region.start.getStr()); - region_pb.set_end_key(region.end.getStr()); - RegionMeta region_meta(std::move(peer), std::move(region_pb), initialApplyState()); - auto raft_index = RAFT_INIT_LOG_INDEX; - region_meta.setApplied(raft_index, RAFT_INIT_LOG_TERM); + RegionMeta region_meta( + std::move(peer), + RegionBench::createMetaRegionCommonHandle(region.id, region.start.getStr(), region.end.getStr()), + initialApplyState()); RegionPtr region_ptr = RegionBench::makeRegion(std::move(region_meta)); tmt.getKVStore()->onSnapshot(region_ptr, nullptr, 0, tmt); - auto & pairs = region.pairs; - for (auto & pair : pairs) 
+ for (auto & pair : region.pairs) { UInt64 prewrite_ts = pd_client->getTS(); UInt64 commit_ts = pd_client->getTS(); @@ -245,7 +244,7 @@ void NaturalDag::buildTables(Context & context) *tmt.getKVStore(), std::move(request), region.id, - MockTiKV::instance().getRaftIndex(region.id), + MockTiKV::instance().getNextRaftIndex(region.id), MockTiKV::instance().getRaftTerm(region.id), tmt); } diff --git a/dbms/src/Debug/dbgNaturalDag.h b/dbms/src/Debug/dbgNaturalDag.h index 20d2a6cc6c7..d8970d6f27f 100644 --- a/dbms/src/Debug/dbgNaturalDag.h +++ b/dbms/src/Debug/dbgNaturalDag.h @@ -16,7 +16,6 @@ #include #include -#include #include #include #include diff --git a/dbms/src/Debug/dbgQueryExecutor.cpp b/dbms/src/Debug/dbgQueryExecutor.cpp index f3906a56a33..615a981e0c2 100644 --- a/dbms/src/Debug/dbgQueryExecutor.cpp +++ b/dbms/src/Debug/dbgQueryExecutor.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include namespace DB diff --git a/dbms/src/Debug/dbgTools.cpp b/dbms/src/Debug/dbgTools.cpp index 2448d2d4605..156596a6db1 100644 --- a/dbms/src/Debug/dbgTools.cpp +++ b/dbms/src/Debug/dbgTools.cpp @@ -14,6 +14,7 @@ #include #include +#include #include #include #include @@ -23,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -30,769 +32,14 @@ #include #include -#include -namespace DB -{ -namespace ErrorCodes +namespace DB::ErrorCodes { extern const int LOGICAL_ERROR; extern const int UNKNOWN_TABLE; extern const int UNKNOWN_DATABASE; -} // namespace ErrorCodes - -namespace RegionBench -{ -using TiDB::ColumnInfo; - -RegionPtr createRegion( - TableID table_id, - RegionID region_id, - const HandleID & start, - const HandleID & end, - std::optional index_) -{ - metapb::Region region; - metapb::Peer peer; - region.set_id(region_id); - - TiKVKey start_key = RecordKVFormat::genKey(table_id, start); - TiKVKey end_key = RecordKVFormat::genKey(table_id, end); - - region.set_start_key(start_key.getStr()); - 
region.set_end_key(end_key.getStr()); - - RegionMeta region_meta(std::move(peer), std::move(region), initialApplyState()); - uint64_t index = MockTiKV::instance().getRaftIndex(region_id); - if (index_) - index = *index_; - region_meta.setApplied(index, RAFT_INIT_LOG_TERM); - return makeRegion(std::move(region_meta)); -} - -Regions createRegions( - TableID table_id, - size_t region_num, - size_t key_num_each_region, - HandleID handle_begin, - RegionID new_region_id_begin) -{ - Regions regions; - for (RegionID region_id = new_region_id_begin; region_id < static_cast(new_region_id_begin + region_num); - ++region_id, handle_begin += key_num_each_region) - { - auto ptr = createRegion(table_id, region_id, handle_begin, handle_begin + key_num_each_region); - regions.push_back(ptr); - } - return regions; -} - -RegionPtr createRegion( - const TiDB::TableInfo & table_info, - RegionID region_id, - std::vector & start_keys, - std::vector & end_keys) -{ - metapb::Region region; - metapb::Peer peer; - region.set_id(region_id); - - TiKVKey start_key = RecordKVFormat::genKey(table_info, start_keys); - TiKVKey end_key = RecordKVFormat::genKey(table_info, end_keys); - - region.set_start_key(start_key.getStr()); - region.set_end_key(end_key.getStr()); - - RegionMeta region_meta(std::move(peer), std::move(region), initialApplyState()); - region_meta.setApplied(MockTiKV::instance().getRaftIndex(region_id), RAFT_INIT_LOG_TERM); - return RegionBench::makeRegion(std::move(region_meta)); -} - -void setupPutRequest(raft_cmdpb::Request * req, const std::string & cf, const TiKVKey & key, const TiKVValue & value) -{ - req->set_cmd_type(raft_cmdpb::CmdType::Put); - raft_cmdpb::PutRequest * put = req->mutable_put(); - put->set_cf(cf.c_str()); - put->set_key(key.getStr()); - put->set_value(value.getStr()); -} - -void setupDelRequest(raft_cmdpb::Request * req, const std::string & cf, const TiKVKey & key) -{ - req->set_cmd_type(raft_cmdpb::CmdType::Delete); - raft_cmdpb::DeleteRequest * del = 
req->mutable_delete_(); - del->set_cf(cf.c_str()); - del->set_key(key.getStr()); -} - -void addRequestsToRaftCmd( - raft_cmdpb::RaftCmdRequest & request, - const TiKVKey & key, - const TiKVValue & value, - UInt64 prewrite_ts, - UInt64 commit_ts, - bool del, - const String pk) -{ - TiKVKey commit_key = RecordKVFormat::appendTs(key, commit_ts); - const TiKVKey & lock_key = key; - - if (del) - { - TiKVValue lock_value = RecordKVFormat::encodeLockCfValue(Region::DelFlag, pk, prewrite_ts, 0); - TiKVValue commit_value = RecordKVFormat::encodeWriteCfValue(Region::DelFlag, prewrite_ts); - - setupPutRequest(request.add_requests(), ColumnFamilyName::Lock, lock_key, lock_value); - setupPutRequest(request.add_requests(), ColumnFamilyName::Write, commit_key, commit_value); - setupDelRequest(request.add_requests(), ColumnFamilyName::Lock, lock_key); - return; - } - - if (value.dataSize() <= RecordKVFormat::SHORT_VALUE_MAX_LEN) - { - TiKVValue lock_value = RecordKVFormat::encodeLockCfValue(Region::PutFlag, pk, prewrite_ts, 0); - - TiKVValue commit_value = RecordKVFormat::encodeWriteCfValue(Region::PutFlag, prewrite_ts, value.toString()); - - setupPutRequest(request.add_requests(), ColumnFamilyName::Lock, lock_key, lock_value); - setupPutRequest(request.add_requests(), ColumnFamilyName::Write, commit_key, commit_value); - setupDelRequest(request.add_requests(), ColumnFamilyName::Lock, lock_key); - } - else - { - TiKVValue lock_value = RecordKVFormat::encodeLockCfValue(Region::PutFlag, pk, prewrite_ts, 0); - - TiKVKey prewrite_key = RecordKVFormat::appendTs(key, prewrite_ts); - const TiKVValue & prewrite_value = value; - - TiKVValue commit_value = RecordKVFormat::encodeWriteCfValue(Region::PutFlag, prewrite_ts); - - setupPutRequest(request.add_requests(), ColumnFamilyName::Lock, lock_key, lock_value); - setupPutRequest(request.add_requests(), ColumnFamilyName::Write, commit_key, commit_value); - setupPutRequest(request.add_requests(), ColumnFamilyName::Default, prewrite_key, 
prewrite_value); - setupDelRequest(request.add_requests(), ColumnFamilyName::Lock, lock_key); - } -} - -template -T convertNumber(const Field & field) -{ - switch (field.getType()) - { - case Field::Types::Int64: - return static_cast(field.get()); - case Field::Types::UInt64: - return static_cast(field.get()); - case Field::Types::Float64: - return static_cast(field.get()); - case Field::Types::Decimal32: - return static_cast(field.get>()); - case Field::Types::Decimal64: - return static_cast(field.get>()); - case Field::Types::Decimal128: - return static_cast(field.get>()); - case Field::Types::Decimal256: - return static_cast(field.get>()); - default: - throw Exception( - String("Unable to convert field type ") + field.getTypeName() + " to number", - ErrorCodes::LOGICAL_ERROR); - } -} - -Field convertDecimal(const ColumnInfo & column_info, const Field & field) -{ - switch (field.getType()) - { - case Field::Types::Int64: - return column_info.getDecimalValue(std::to_string(field.get())); - case Field::Types::UInt64: - return column_info.getDecimalValue(std::to_string(field.get())); - case Field::Types::Float64: - return column_info.getDecimalValue(std::to_string(field.get())); - case Field::Types::Decimal32: - return column_info.getDecimalValue(field.get().toString(column_info.decimal)); - case Field::Types::Decimal64: - return column_info.getDecimalValue(field.get().toString(column_info.decimal)); - case Field::Types::Decimal128: - return column_info.getDecimalValue(field.get().toString(column_info.decimal)); - case Field::Types::Decimal256: - return column_info.getDecimalValue(field.get().toString(column_info.decimal)); - default: - throw Exception( - String("Unable to convert field type ") + field.getTypeName() + " to number", - ErrorCodes::LOGICAL_ERROR); - } -} - -Field convertEnum(const ColumnInfo & column_info, const Field & field) -{ - switch (field.getType()) - { - case Field::Types::Int64: - case Field::Types::UInt64: - return convertNumber(field); - 
case Field::Types::String: - return static_cast(column_info.getEnumIndex(field.get())); - default: - throw Exception( - String("Unable to convert field type ") + field.getTypeName() + " to Enum", - ErrorCodes::LOGICAL_ERROR); - } -} - -Field convertField(const ColumnInfo & column_info, const Field & field) -{ - if (field.isNull()) - return field; - - switch (column_info.tp) - { - case TiDB::TypeTiny: - case TiDB::TypeShort: - case TiDB::TypeLong: - case TiDB::TypeLongLong: - case TiDB::TypeInt24: - if (column_info.hasUnsignedFlag()) - return convertNumber(field); - else - return convertNumber(field); - case TiDB::TypeFloat: - case TiDB::TypeDouble: - return convertNumber(field); - case TiDB::TypeDate: - case TiDB::TypeDatetime: - case TiDB::TypeTimestamp: - return parseMyDateTime(field.safeGet()); - case TiDB::TypeVarchar: - case TiDB::TypeTinyBlob: - case TiDB::TypeMediumBlob: - case TiDB::TypeLongBlob: - case TiDB::TypeBlob: - case TiDB::TypeVarString: - case TiDB::TypeString: - return field; - case TiDB::TypeEnum: - return convertEnum(column_info, field); - case TiDB::TypeNull: - return Field(); - case TiDB::TypeDecimal: - case TiDB::TypeNewDecimal: - return convertDecimal(column_info, field); - case TiDB::TypeTime: - case TiDB::TypeYear: - return convertNumber(field); - case TiDB::TypeSet: - case TiDB::TypeBit: - return convertNumber(field); - default: - return Field(); - } -} - -void encodeRow(const TiDB::TableInfo & table_info, const std::vector & fields, WriteBuffer & ss) -{ - if (table_info.columns.size() < fields.size() + table_info.pk_is_handle) - throw Exception( - "Encoding row has less columns than encode values [num_columns=" + DB::toString(table_info.columns.size()) - + "] [num_fields=" + DB::toString(fields.size()) + "] . 
", - ErrorCodes::LOGICAL_ERROR); - - std::vector flatten_fields; - std::unordered_set pk_column_names; - if (table_info.is_common_handle) - { - for (const auto & idx_col : table_info.getPrimaryIndexInfo().idx_cols) - { - // todo support prefix index - pk_column_names.insert(idx_col.name); - } - } - for (size_t i = 0; i < fields.size(); i++) - { - const auto & column_info = table_info.columns[i]; - /// skip the columns encoded in the key - if (pk_column_names.find(column_info.name) != pk_column_names.end()) - continue; - Field field = convertField(column_info, fields[i]); - TiDB::DatumBumpy datum = TiDB::DatumBumpy(field, column_info.tp); - flatten_fields.emplace_back(datum.field()); - } - - static bool row_format_flip = false; - // Ping-pong encoding using row format V1/V2. - (row_format_flip = !row_format_flip) ? encodeRowV1(table_info, flatten_fields, ss) - : encodeRowV2(table_info, flatten_fields, ss); -} - -void insert( // - const TiDB::TableInfo & table_info, - RegionID region_id, - HandleID handle_id, // - ASTs::const_iterator values_begin, - ASTs::const_iterator values_end, // - Context & context, - const std::optional> & tso_del) -{ - // Parse the fields in the inserted row - std::vector fields; - { - for (auto it = values_begin; it != values_end; ++it) - { - auto field = typeid_cast((*it).get())->value; - fields.emplace_back(field); - } - if (fields.size() + table_info.pk_is_handle != table_info.columns.size()) - throw Exception("Number of insert values and columns do not match.", ErrorCodes::LOGICAL_ERROR); - } - TMTContext & tmt = context.getTMTContext(); - pingcap::pd::ClientPtr pd_client = tmt.getPDClient(); - RegionPtr region = tmt.getKVStore()->getRegion(region_id); - - // Using the region meta's table ID rather than table_info's, as this could be a partition table so that the table ID should be partition ID. 
- const auto range = region->getRange(); - TableID table_id = RecordKVFormat::getTableId(*range->rawKeys().first); - - TiKVKey key; - if (table_info.is_common_handle) - { - std::vector keys; - const auto & pk_index = table_info.getPrimaryIndexInfo(); - for (const auto & idx_col : pk_index.idx_cols) - { - const auto & column_info = table_info.columns[idx_col.offset]; - auto start_field = RegionBench::convertField(column_info, fields[idx_col.offset]); - TiDB::DatumBumpy start_datum = TiDB::DatumBumpy(start_field, column_info.tp); - keys.emplace_back(start_datum.field()); - } - key = RecordKVFormat::genKey(table_info, keys); - } - else - key = RecordKVFormat::genKey(table_id, handle_id); - WriteBufferFromOwnString ss; - encodeRow(table_info, fields, ss); - TiKVValue value(ss.releaseStr()); - - UInt64 prewrite_ts = pd_client->getTS(); - UInt64 commit_ts = pd_client->getTS(); - bool is_del = false; - - if (tso_del.has_value()) - { - auto [tso, del] = *tso_del; - prewrite_ts = tso; - commit_ts = tso; - is_del = del; - } - - raft_cmdpb::RaftCmdRequest request; - addRequestsToRaftCmd(request, key, value, prewrite_ts, commit_ts, is_del); - RegionBench::applyWriteRaftCmd( - *tmt.getKVStore(), - std::move(request), - region_id, - MockTiKV::instance().getRaftIndex(region_id), - MockTiKV::instance().getRaftTerm(region_id), - tmt); -} - -void remove(const TiDB::TableInfo & table_info, RegionID region_id, HandleID handle_id, Context & context) -{ - static const TiKVValue value; - - TiKVKey key = RecordKVFormat::genKey(table_info.id, handle_id); - - TMTContext & tmt = context.getTMTContext(); - pingcap::pd::ClientPtr pd_client = tmt.getPDClient(); - RegionPtr region = tmt.getKVStore()->getRegion(region_id); - - UInt64 prewrite_ts = pd_client->getTS(); - UInt64 commit_ts = pd_client->getTS(); - - raft_cmdpb::RaftCmdRequest request; - addRequestsToRaftCmd(request, key, value, prewrite_ts, commit_ts, true); - RegionBench::applyWriteRaftCmd( - *tmt.getKVStore(), - std::move(request), 
- region_id, - MockTiKV::instance().getRaftIndex(region_id), - MockTiKV::instance().getRaftTerm(region_id), - tmt); -} - -struct BatchCtrl -{ - String default_str; - Int64 concurrent_id; - Int64 flush_num; - Int64 batch_num; - UInt64 min_strlen; - UInt64 max_strlen; - Context * context; - RegionPtr region; - HandleID handle_begin; - bool del; - - BatchCtrl( - Int64 concurrent_id_, - Int64 flush_num_, - Int64 batch_num_, - UInt64 min_strlen_, - UInt64 max_strlen_, - Context * context_, - RegionPtr region_, - HandleID handle_begin_, - bool del_) - : concurrent_id(concurrent_id_) - , flush_num(flush_num_) - , batch_num(batch_num_) - , min_strlen(min_strlen_) - , max_strlen(max_strlen_) - , context(context_) - , region(region_) - , handle_begin(handle_begin_) - , del(del_) - { - assert(max_strlen >= min_strlen); - assert(min_strlen >= 1); - auto str_len = static_cast(random() % (max_strlen - min_strlen + 1) + min_strlen); - default_str = String(str_len, '_'); - } - - void encodeDatum(WriteBuffer & ss, TiDB::CodecFlag flag, Int64 magic_num) - { - Int8 target = (magic_num % 70) + '0'; - EncodeUInt(static_cast(flag), ss); - switch (flag) - { - case TiDB::CodecFlagJson: - throw Exception( - "Not implented yet: BatchCtrl::encodeDatum, TiDB::CodecFlagJson", - ErrorCodes::LOGICAL_ERROR); - case TiDB::CodecFlagVectorFloat32: - throw Exception( - "Not implented yet: BatchCtrl::encodeDatum, TiDB::CodecFlagVectorFloat32", - ErrorCodes::LOGICAL_ERROR); - case TiDB::CodecFlagMax: - throw Exception("Not implented yet: BatchCtrl::encodeDatum, TiDB::CodecFlagMax", ErrorCodes::LOGICAL_ERROR); - case TiDB::CodecFlagDuration: - throw Exception( - "Not implented yet: BatchCtrl::encodeDatum, TiDB::CodecFlagDuration", - ErrorCodes::LOGICAL_ERROR); - case TiDB::CodecFlagNil: - return; - case TiDB::CodecFlagBytes: - memset(default_str.data(), target, default_str.size()); - return EncodeBytes(default_str, ss); - //case TiDB::CodecFlagDecimal: - // return EncodeDecimal(Decimal(magic_num), ss); 
- case TiDB::CodecFlagCompactBytes: - memset(default_str.data(), target, default_str.size()); - return EncodeCompactBytes(default_str, ss); - case TiDB::CodecFlagFloat: - return EncodeFloat64(static_cast(magic_num) / 1111.1, ss); - case TiDB::CodecFlagUInt: - return EncodeUInt(static_cast(magic_num), ss); - case TiDB::CodecFlagInt: - return EncodeInt64((magic_num), ss); - case TiDB::CodecFlagVarInt: - return EncodeVarInt((magic_num), ss); - case TiDB::CodecFlagVarUInt: - return EncodeVarUInt(static_cast(magic_num), ss); - default: - throw Exception("Not implented codec flag: " + std::to_string(flag), ErrorCodes::LOGICAL_ERROR); - } - } - - TiKVValue encodeRow(const TiDB::TableInfo & table_info, Int64 magic_num) - { - WriteBufferFromOwnString ss; - for (const auto & column : table_info.columns) - { - encodeDatum(ss, TiDB::CodecFlagInt, column.id); - // TODO: May need to use BumpyDatum to flatten before encoding. - encodeDatum(ss, column.getCodecFlag(), magic_num); - } - return TiKVValue(ss.releaseStr()); - } -}; - -void batchInsert( - const TiDB::TableInfo & table_info, - std::unique_ptr batch_ctrl, - std::function fn_gen_magic_num) -{ - RegionPtr & region = batch_ctrl->region; - - TMTContext & tmt = batch_ctrl->context->getTMTContext(); - pingcap::pd::ClientPtr pd_client = tmt.getPDClient(); - - Int64 index = batch_ctrl->handle_begin; - - for (Int64 flush_cnt = 0; flush_cnt < batch_ctrl->flush_num; ++flush_cnt) - { - UInt64 prewrite_ts = pd_client->getTS(); - UInt64 commit_ts = pd_client->getTS(); - - raft_cmdpb::RaftCmdRequest request; - - for (Int64 cnt = 0; cnt < batch_ctrl->batch_num; ++index, ++cnt) - { - TiKVKey key = RecordKVFormat::genKey(table_info.id, index); - TiKVValue value = batch_ctrl->encodeRow(table_info, fn_gen_magic_num(index)); - addRequestsToRaftCmd(request, key, value, prewrite_ts, commit_ts, batch_ctrl->del); - } - - RegionBench::applyWriteRaftCmd( - *tmt.getKVStore(), - std::move(request), - region->id(), - 
MockTiKV::instance().getRaftIndex(region->id()), - MockTiKV::instance().getRaftTerm(region->id()), - tmt); - } -} - -void concurrentBatchInsert( - const TiDB::TableInfo & table_info, - Int64 concurrent_num, - Int64 flush_num, - Int64 batch_num, - UInt64 min_strlen, - UInt64 max_strlen, - Context & context) -{ - TMTContext & tmt = context.getTMTContext(); - - RegionID curr_max_region_id(InvalidRegionID); - HandleID curr_max_handle_id = 0; - tmt.getKVStore()->traverseRegions([&](const RegionID region_id, const RegionPtr & region) { - curr_max_region_id - = (curr_max_region_id == InvalidRegionID) ? region_id : std::max(curr_max_region_id, region_id); - const auto range = region->getRange(); - curr_max_handle_id = std::max(RecordKVFormat::getHandle(*range->rawKeys().second), curr_max_handle_id); - }); - - Int64 key_num_each_region = flush_num * batch_num; - HandleID handle_begin = curr_max_handle_id; - +} // namespace DB::ErrorCodes - auto debug_kvstore = RegionBench::DebugKVStore(*tmt.getKVStore()); - Regions regions - = createRegions(table_info.id, concurrent_num, key_num_each_region, handle_begin, curr_max_region_id + 1); - for (const RegionPtr & region : regions) - debug_kvstore.onSnapshot(RegionPtrWithSnapshotFiles{region, {}}, nullptr, 0, tmt); - - std::list threads; - for (Int64 i = 0; i < concurrent_num; i++, handle_begin += key_num_each_region) - { - auto batch_ptr = std::make_unique< - BatchCtrl>(i, flush_num, batch_num, min_strlen, max_strlen, &context, regions[i], handle_begin, false); - threads.push_back( - std::thread(&batchInsert, table_info, std::move(batch_ptr), [](Int64 index) -> Int64 { return index; })); - } - for (auto & thread : threads) - { - thread.join(); - } -} - -Int64 concurrentRangeOperate( - const TiDB::TableInfo & table_info, - HandleID start_handle, - HandleID end_handle, - Context & context, - Int64 magic_num, - bool del) -{ - Regions regions; - - { - TMTContext & tmt = context.getTMTContext(); - for (auto && [_, r] : 
tmt.getRegionTable().getRegionsByTable(NullspaceID, table_info.id)) - { - std::ignore = _; - if (r == nullptr) - continue; - regions.push_back(r); - } - } - - std::shuffle(regions.begin(), regions.end(), std::default_random_engine()); - - std::list threads; - Int64 tol = 0; - for (const auto & region : regions) - { - const auto range = region->getRange(); - const auto & [ss, ee] = getHandleRangeByTable(range->rawKeys(), table_info.id); - TiKVRange::Handle handle_begin = std::max(ss, start_handle); - TiKVRange::Handle handle_end = std::min(ee, end_handle); - if (handle_end <= handle_begin) - continue; - Int64 batch_num = handle_end - handle_begin; - tol += batch_num; - auto batch_ptr - = std::make_unique(-1, 1, batch_num, 1, 1, &context, region, handle_begin.handle_id, del); - threads.push_back(std::thread(&batchInsert, table_info, std::move(batch_ptr), [=](Int64 index) -> Int64 { - std::ignore = index; - return magic_num; - })); - } - for (auto & thread : threads) - { - thread.join(); - } - return tol; -} - -TableID getTableID( - Context & context, - const std::string & database_name, - const std::string & table_name, - const std::string & partition_id) -{ - try - { - using TablePtr = MockTiDB::TablePtr; - TablePtr table = MockTiDB::instance().getTableByName(database_name, table_name); - - if (table->isPartitionTable()) - return std::strtol(partition_id.c_str(), nullptr, 0); - - return table->id(); - } - catch (Exception & e) - { - if (e.code() != ErrorCodes::UNKNOWN_TABLE) - throw; - } - - auto mapped_table_name = mappedTable(context, database_name, table_name).second; - auto mapped_database_name = mappedDatabase(context, database_name); - auto storage = context.getTable(mapped_database_name, mapped_table_name); - auto managed_storage = std::static_pointer_cast(storage); - auto table_info = managed_storage->getTableInfo(); - return table_info.id; -} - -const TiDB::TableInfo & getTableInfo(Context & context, const String & database_name, const String & table_name) 
-{ - try - { - using TablePtr = MockTiDB::TablePtr; - TablePtr table = MockTiDB::instance().getTableByName(database_name, table_name); - - return table->table_info; - } - catch (Exception & e) - { - if (e.code() != ErrorCodes::UNKNOWN_TABLE) - throw; - } - - auto mapped_table_name = mappedTable(context, database_name, table_name).second; - auto mapped_database_name = mappedDatabase(context, database_name); - auto storage = context.getTable(mapped_database_name, mapped_table_name); - auto managed_storage = std::static_pointer_cast(storage); - return managed_storage->getTableInfo(); -} - - -EngineStoreApplyRes applyWriteRaftCmd( - KVStore & kvstore, - raft_cmdpb::RaftCmdRequest && request, - UInt64 region_id, - UInt64 index, - UInt64 term, - TMTContext & tmt, - DM::WriteResult * write_result_ptr) -{ - std::vector keys; - std::vector vals; - std::vector cmd_types; - std::vector cmd_cf; - keys.reserve(request.requests_size()); - vals.reserve(request.requests_size()); - cmd_types.reserve(request.requests_size()); - cmd_cf.reserve(request.requests_size()); - - for (const auto & req : request.requests()) - { - auto type = req.cmd_type(); - - switch (type) - { - case raft_cmdpb::CmdType::Put: - keys.push_back({req.put().key().data(), req.put().key().size()}); - vals.push_back({req.put().value().data(), req.put().value().size()}); - cmd_types.push_back(WriteCmdType::Put); - cmd_cf.push_back(NameToCF(req.put().cf())); - break; - case raft_cmdpb::CmdType::Delete: - keys.push_back({req.delete_().key().data(), req.delete_().key().size()}); - vals.push_back({nullptr, 0}); - cmd_types.push_back(WriteCmdType::Del); - cmd_cf.push_back(NameToCF(req.delete_().cf())); - break; - default: - throw Exception( - fmt::format("Unsupport raft cmd {}", raft_cmdpb::CmdType_Name(type)), - ErrorCodes::LOGICAL_ERROR); - } - } - if (write_result_ptr) - { - return kvstore.handleWriteRaftCmdInner( - WriteCmdsView{ - .keys = keys.data(), - .vals = vals.data(), - .cmd_types = cmd_types.data(), - 
.cmd_cf = cmd_cf.data(), - .len = keys.size()}, - region_id, - index, - term, - tmt, - *write_result_ptr); - } - else - { - DM::WriteResult write_result; - return kvstore.handleWriteRaftCmdInner( - WriteCmdsView{ - .keys = keys.data(), - .vals = vals.data(), - .cmd_types = cmd_types.data(), - .cmd_cf = cmd_cf.data(), - .len = keys.size()}, - region_id, - index, - term, - tmt, - write_result); - } -} - -void handleApplySnapshot( - KVStore & kvstore, - metapb::Region && region, - uint64_t peer_id, - SSTViewVec snaps, - uint64_t index, - uint64_t term, - std::optional deadline_index, - TMTContext & tmt) -{ - auto new_region = kvstore.genRegionPtr(std::move(region), peer_id, index, term, tmt.getRegionTable()); - auto prehandle_result = kvstore.preHandleSnapshotToFiles(new_region, snaps, index, term, deadline_index, tmt); - kvstore.applyPreHandledSnapshot( - RegionPtrWithSnapshotFiles{new_region, std::move(prehandle_result.ingest_ids)}, - tmt); -} - -} // namespace RegionBench -} // namespace DB namespace DB { String mappedDatabase(Context & context, const String & database_name) diff --git a/dbms/src/Debug/dbgTools.h b/dbms/src/Debug/dbgTools.h index f674db8fc5e..1828c24f019 100644 --- a/dbms/src/Debug/dbgTools.h +++ b/dbms/src/Debug/dbgTools.h @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -32,105 +33,10 @@ struct TableInfo; namespace DB { class Context; -class Region; -using RegionPtr = std::shared_ptr; -using Regions = std::vector; class KVStore; class TMTContext; } // namespace DB -namespace DB::RegionBench -{ -RegionPtr createRegion( - TableID table_id, - RegionID region_id, - const HandleID & start, - const HandleID & end, - std::optional index = std::nullopt); - -Regions createRegions( - TableID table_id, - size_t region_num, - size_t key_num_each_region, - HandleID handle_begin, - RegionID new_region_id_begin); - -RegionPtr createRegion( - const TiDB::TableInfo & table_info, - RegionID region_id, - std::vector & start_keys, - 
std::vector & end_keys); - -void encodeRow(const TiDB::TableInfo & table_info, const std::vector & fields, WriteBuffer & ss); - -void insert( - const TiDB::TableInfo & table_info, - RegionID region_id, - HandleID handle_id, - ASTs::const_iterator begin, - ASTs::const_iterator end, - Context & context, - const std::optional> & tso_del = {}); - -void addRequestsToRaftCmd( - raft_cmdpb::RaftCmdRequest & request, - const TiKVKey & key, - const TiKVValue & value, - UInt64 prewrite_ts, - UInt64 commit_ts, - bool del, - const String pk = "pk"); - -void concurrentBatchInsert( - const TiDB::TableInfo & table_info, - Int64 concurrent_num, - Int64 flush_num, - Int64 batch_num, - UInt64 min_strlen, - UInt64 max_strlen, - Context & context); - -void remove(const TiDB::TableInfo & table_info, RegionID region_id, HandleID handle_id, Context & context); - -Int64 concurrentRangeOperate( - const TiDB::TableInfo & table_info, - HandleID start_handle, - HandleID end_handle, - Context & context, - Int64 magic_num, - bool del); - -Field convertField(const TiDB::ColumnInfo & column_info, const Field & field); - -TableID getTableID( - Context & context, - const std::string & database_name, - const std::string & table_name, - const std::string & partition_id); - -const TiDB::TableInfo & getTableInfo(Context & context, const String & database_name, const String & table_name); - -EngineStoreApplyRes applyWriteRaftCmd( - KVStore & kvstore, - raft_cmdpb::RaftCmdRequest && request, - UInt64 region_id, - UInt64 index, - UInt64 term, - TMTContext & tmt, - ::DB::DM::WriteResult * write_result_ptr = nullptr); - -void handleApplySnapshot( - KVStore & kvstore, - metapb::Region && region, - uint64_t peer_id, - SSTViewVec, - uint64_t index, - uint64_t term, - std::optional, - TMTContext & tmt); - -} // namespace DB::RegionBench - namespace DB { using QualifiedName = std::pair; diff --git a/dbms/src/Dictionaries/ComplexKeyCacheDictionary.cpp b/dbms/src/Dictionaries/ComplexKeyCacheDictionary.cpp deleted 
file mode 100644 index 83694d0da81..00000000000 --- a/dbms/src/Dictionaries/ComplexKeyCacheDictionary.cpp +++ /dev/null @@ -1,393 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -namespace DB -{ -namespace ErrorCodes -{ -extern const int TYPE_MISMATCH; -extern const int BAD_ARGUMENTS; -extern const int UNSUPPORTED_METHOD; -} // namespace ErrorCodes - - -inline UInt64 ComplexKeyCacheDictionary::getCellIdx(const StringRef key) const -{ - const auto hash = StringRefHash{}(key); - const auto idx = hash & size_overlap_mask; - return idx; -} - - -ComplexKeyCacheDictionary::ComplexKeyCacheDictionary( - const std::string & name, - const DictionaryStructure & dict_struct, - DictionarySourcePtr source_ptr, - const DictionaryLifetime dict_lifetime, - const size_t size) - : name{name} - , dict_struct(dict_struct) - , source_ptr{std::move(source_ptr)} - , dict_lifetime(dict_lifetime) - , size{roundUpToPowerOfTwoOrZero(std::max(size, size_t(max_collision_length)))} - , size_overlap_mask{this->size - 1} - , rnd_engine(randomSeed()) -{ - if (!this->source_ptr->supportsSelectiveLoad()) - throw Exception{ - name + ": source cannot be used with ComplexKeyCacheDictionary", - ErrorCodes::UNSUPPORTED_METHOD}; - - createAttributes(); -} - -ComplexKeyCacheDictionary::ComplexKeyCacheDictionary(const ComplexKeyCacheDictionary & other) - : 
ComplexKeyCacheDictionary{ - other.name, - other.dict_struct, - other.source_ptr->clone(), - other.dict_lifetime, - other.size} -{} - -void ComplexKeyCacheDictionary::getString( - const std::string & attribute_name, - const Columns & key_columns, - const DataTypes & key_types, - ColumnString * out) const -{ - dict_struct.validateKeyTypes(key_types); - - auto & attribute = getAttribute(attribute_name); - if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String)) - throw Exception{ - name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type), - ErrorCodes::TYPE_MISMATCH}; - - const auto null_value = StringRef{std::get(attribute.null_values)}; - - getItemsString(attribute, key_columns, out, [&](const size_t) { return null_value; }); -} - -void ComplexKeyCacheDictionary::getString( - const std::string & attribute_name, - const Columns & key_columns, - const DataTypes & key_types, - const ColumnString * const def, - ColumnString * const out) const -{ - dict_struct.validateKeyTypes(key_types); - - auto & attribute = getAttribute(attribute_name); - if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String)) - throw Exception{ - name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type), - ErrorCodes::TYPE_MISMATCH}; - - getItemsString(attribute, key_columns, out, [&](const size_t row) { return def->getDataAt(row); }); -} - -void ComplexKeyCacheDictionary::getString( - const std::string & attribute_name, - const Columns & key_columns, - const DataTypes & key_types, - const String & def, - ColumnString * const out) const -{ - dict_struct.validateKeyTypes(key_types); - - auto & attribute = getAttribute(attribute_name); - if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String)) - throw Exception{ - name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type), - ErrorCodes::TYPE_MISMATCH}; - - 
getItemsString(attribute, key_columns, out, [&](const size_t) { return StringRef{def}; }); -} - -/// returns cell_idx (always valid for replacing), 'cell is valid' flag, 'cell is outdated' flag, -/// true false found and valid -/// false true not found (something outdated, maybe our cell) -/// false false not found (other id stored with valid data) -/// true true impossible -/// -/// todo: split this func to two: find_for_get and find_for_set -ComplexKeyCacheDictionary::FindResult ComplexKeyCacheDictionary::findCellIdx( - const StringRef & key, - const CellMetadata::time_point_t now, - const size_t hash) const -{ - auto pos = hash; - auto oldest_id = pos; - auto oldest_time = CellMetadata::time_point_t::max(); - const auto stop = pos + max_collision_length; - - for (; pos < stop; ++pos) - { - const auto cell_idx = pos & size_overlap_mask; - const auto & cell = cells[cell_idx]; - - if (cell.hash != hash || cell.key != key) - { - /// maybe we already found nearest expired cell - if (oldest_time > now && oldest_time > cell.expiresAt()) - { - oldest_time = cell.expiresAt(); - oldest_id = cell_idx; - } - - continue; - } - - if (cell.expiresAt() < now) - { - return {cell_idx, false, true}; - } - - return {cell_idx, true, false}; - } - - oldest_id &= size_overlap_mask; - return {oldest_id, false, false}; -} - -void ComplexKeyCacheDictionary::has( - const Columns & key_columns, - const DataTypes & key_types, - PaddedPODArray & out) const -{ - dict_struct.validateKeyTypes(key_types); - - /// Mapping: -> { all indices `i` of `key_columns` such that `key_columns[i]` = } - MapType> outdated_keys; - - - const auto rows_num = key_columns.front()->size(); - const auto keys_size = dict_struct.key->size(); - StringRefs keys(keys_size); - Arena temporary_keys_pool; - PODArray keys_array(rows_num); - - { - const auto now = std::chrono::system_clock::now(); - /// fetch up-to-date values, decide which ones require update - for (const auto row : ext::range(0, rows_num)) - { - const 
StringRef key = placeKeysInPool(row, key_columns, keys, *dict_struct.key, temporary_keys_pool); - keys_array[row] = key; - const auto find_result = findCellIdx(key, now); - const auto & cell_idx = find_result.cell_idx; - /** cell should be updated if either: - * 1. keys (or hash) do not match, - * 2. cell has expired, - * 3. explicit defaults were specified and cell was set default. */ - if (!find_result.valid) - { - outdated_keys[key].push_back(row); - } - else - { - const auto & cell = cells[cell_idx]; - out[row] = !cell.isDefault(); - } - } - } - - query_count.fetch_add(rows_num, std::memory_order_relaxed); - hit_count.fetch_add(rows_num - outdated_keys.size(), std::memory_order_release); - - if (outdated_keys.empty()) - return; - - std::vector required_rows(outdated_keys.size()); - std::transform(std::begin(outdated_keys), std::end(outdated_keys), std::begin(required_rows), [](auto & pair) { - return pair.getMapped().front(); - }); - - /// request new values - update( - key_columns, - keys_array, - required_rows, - [&](const StringRef key, const auto) { - for (const auto out_idx : outdated_keys[key]) - out[out_idx] = true; - }, - [&](const StringRef key, const auto) { - for (const auto out_idx : outdated_keys[key]) - out[out_idx] = false; - }); -} - -void ComplexKeyCacheDictionary::createAttributes() -{ - const auto attributes_size = dict_struct.attributes.size(); - attributes.reserve(attributes_size); - - bytes_allocated += size * sizeof(CellMetadata); - bytes_allocated += attributes_size * sizeof(attributes.front()); - - for (const auto & attribute : dict_struct.attributes) - { - attribute_index_by_name.emplace(attribute.name, attributes.size()); - attributes.push_back(createAttributeWithType(attribute.underlying_type, attribute.null_value)); - - if (attribute.hierarchical) - throw Exception{ - name + ": hierarchical attributes not supported for dictionary of type " + getTypeName(), - ErrorCodes::TYPE_MISMATCH}; - } -} - -ComplexKeyCacheDictionary::Attribute 
& ComplexKeyCacheDictionary::getAttribute(const std::string & attribute_name) const -{ - const auto it = attribute_index_by_name.find(attribute_name); - if (it == std::end(attribute_index_by_name)) - throw Exception{name + ": no such attribute '" + attribute_name + "'", ErrorCodes::BAD_ARGUMENTS}; - - return attributes[it->second]; -} - -StringRef ComplexKeyCacheDictionary::allocKey(const size_t row, const Columns & key_columns, StringRefs & keys) const -{ - if (key_size_is_fixed) - return placeKeysInFixedSizePool(row, key_columns); - - return placeKeysInPool(row, key_columns, keys, *dict_struct.key, *keys_pool); -} - -void ComplexKeyCacheDictionary::freeKey(const StringRef key) const -{ - if (key_size_is_fixed) - fixed_size_keys_pool->free(const_cast(key.data)); - else - keys_pool->free(const_cast(key.data), key.size); -} - -template -StringRef ComplexKeyCacheDictionary::placeKeysInPool( - const size_t row, - const Columns & key_columns, - StringRefs & keys, - const std::vector & key_attributes, - Pool & pool) -{ - const auto keys_size = key_columns.size(); - size_t sum_keys_size{}; - - for (size_t j = 0; j < keys_size; ++j) - { - keys[j] = key_columns[j]->getDataAt(row); - sum_keys_size += keys[j].size; - if (key_attributes[j].underlying_type == AttributeUnderlyingType::String) - sum_keys_size += sizeof(size_t) + 1; - } - - auto place = pool.alloc(sum_keys_size); - - auto key_start = place; - for (size_t j = 0; j < keys_size; ++j) - { - if (key_attributes[j].underlying_type == AttributeUnderlyingType::String) - { - auto start = key_start; - auto key_size = keys[j].size + 1; - memcpy(key_start, &key_size, sizeof(size_t)); - key_start += sizeof(size_t); - memcpy(key_start, keys[j].data, keys[j].size); - key_start += keys[j].size; - *key_start = '\0'; - ++key_start; - keys[j].data = start; - keys[j].size += sizeof(size_t) + 1; - } - else - { - memcpy(key_start, keys[j].data, keys[j].size); - keys[j].data = key_start; - key_start += keys[j].size; - } - } - - return 
{place, sum_keys_size}; -} - -/// Explicit instantiations. - -template StringRef ComplexKeyCacheDictionary::placeKeysInPool( - const size_t row, - const Columns & key_columns, - StringRefs & keys, - const std::vector & key_attributes, - Arena & pool); - -template StringRef ComplexKeyCacheDictionary::placeKeysInPool( - const size_t row, - const Columns & key_columns, - StringRefs & keys, - const std::vector & key_attributes, - ArenaWithFreeLists & pool); - - -StringRef ComplexKeyCacheDictionary::placeKeysInFixedSizePool(const size_t row, const Columns & key_columns) const -{ - auto * const res = fixed_size_keys_pool->alloc(); - auto * place = res; - - for (const auto & key_column : key_columns) - { - const StringRef key = key_column->getDataAt(row); - memcpy(place, key.data, key.size); - place += key.size; - } - - return {res, key_size}; -} - -StringRef ComplexKeyCacheDictionary::copyIntoArena(StringRef src, Arena & arena) -{ - char * allocated = arena.alloc(src.size); - memcpy(allocated, src.data, src.size); - return {allocated, src.size}; -} - -StringRef ComplexKeyCacheDictionary::copyKey(const StringRef key) const -{ - auto * const res = key_size_is_fixed ? fixed_size_keys_pool->alloc() : keys_pool->alloc(key.size); - memcpy(res, key.data, key.size); - - return {res, key.size}; -} - -bool ComplexKeyCacheDictionary::isEmptyCell(const UInt64 idx) const -{ - return ( - cells[idx].key == StringRef{} - && (idx != zero_cell_idx - || cells[idx].data == ext::safe_bit_cast(CellMetadata::time_point_t()))); -} - -} // namespace DB diff --git a/dbms/src/Dictionaries/ComplexKeyCacheDictionary.h b/dbms/src/Dictionaries/ComplexKeyCacheDictionary.h deleted file mode 100644 index 22f7f4453c0..00000000000 --- a/dbms/src/Dictionaries/ComplexKeyCacheDictionary.h +++ /dev/null @@ -1,689 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace DB -{ -class ComplexKeyCacheDictionary final : public IDictionaryBase -{ -public: - ComplexKeyCacheDictionary( - const std::string & name, - const DictionaryStructure & dict_struct, - DictionarySourcePtr source_ptr, - const DictionaryLifetime dict_lifetime, - const size_t size); - - ComplexKeyCacheDictionary(const ComplexKeyCacheDictionary & other); - - std::string getKeyDescription() const { return key_description; }; - - std::exception_ptr getCreationException() const override { return {}; } - - std::string getName() const override { return name; } - - std::string getTypeName() const override { return "ComplexKeyCache"; } - - size_t getBytesAllocated() const override - { - return bytes_allocated + (key_size_is_fixed ? fixed_size_keys_pool->size() : keys_pool->size()) - + (string_arena ? 
string_arena->size() : 0); - } - - size_t getQueryCount() const override { return query_count.load(std::memory_order_relaxed); } - - double getHitRate() const override - { - return static_cast(hit_count.load(std::memory_order_acquire)) - / query_count.load(std::memory_order_relaxed); - } - - size_t getElementCount() const override { return element_count.load(std::memory_order_relaxed); } - - double getLoadFactor() const override - { - return static_cast(element_count.load(std::memory_order_relaxed)) / size; - } - - bool isCached() const override { return true; } - - std::unique_ptr clone() const override - { - return std::make_unique(*this); - } - - const IDictionarySource * getSource() const override { return source_ptr.get(); } - - const DictionaryLifetime & getLifetime() const override { return dict_lifetime; } - - const DictionaryStructure & getStructure() const override { return dict_struct; } - - std::chrono::time_point getCreationTime() const override { return creation_time; } - - bool isInjective(const std::string & attribute_name) const override - { - return dict_struct.attributes[&getAttribute(attribute_name) - attributes.data()].injective; - } - -/// In all functions below, key_columns must be full (non-constant) columns. -/// See the requirement in IDataType.h for text-serialization functions. 
-#define DECLARE(TYPE) \ - void get##TYPE( \ - const std::string & attribute_name, \ - const Columns & key_columns, \ - const DataTypes & key_types, \ - PaddedPODArray & out) const; - DECLARE(UInt8) - DECLARE(UInt16) - DECLARE(UInt32) - DECLARE(UInt64) - DECLARE(UInt128) - DECLARE(Int8) - DECLARE(Int16) - DECLARE(Int32) - DECLARE(Int64) - DECLARE(Float32) - DECLARE(Float64) -#undef DECLARE - - void getString( - const std::string & attribute_name, - const Columns & key_columns, - const DataTypes & key_types, - ColumnString * out) const; - -#define DECLARE(TYPE) \ - void get##TYPE( \ - const std::string & attribute_name, \ - const Columns & key_columns, \ - const DataTypes & key_types, \ - const PaddedPODArray & def, \ - PaddedPODArray & out) const; - DECLARE(UInt8) - DECLARE(UInt16) - DECLARE(UInt32) - DECLARE(UInt64) - DECLARE(UInt128) - DECLARE(Int8) - DECLARE(Int16) - DECLARE(Int32) - DECLARE(Int64) - DECLARE(Float32) - DECLARE(Float64) -#undef DECLARE - - void getString( - const std::string & attribute_name, - const Columns & key_columns, - const DataTypes & key_types, - const ColumnString * const def, - ColumnString * const out) const; - -#define DECLARE(TYPE) \ - void get##TYPE( \ - const std::string & attribute_name, \ - const Columns & key_columns, \ - const DataTypes & key_types, \ - const TYPE def, \ - PaddedPODArray & out) const; - DECLARE(UInt8) - DECLARE(UInt16) - DECLARE(UInt32) - DECLARE(UInt64) - DECLARE(UInt128) - DECLARE(Int8) - DECLARE(Int16) - DECLARE(Int32) - DECLARE(Int64) - DECLARE(Float32) - DECLARE(Float64) -#undef DECLARE - - void getString( - const std::string & attribute_name, - const Columns & key_columns, - const DataTypes & key_types, - const String & def, - ColumnString * const out) const; - - void has(const Columns & key_columns, const DataTypes & key_types, PaddedPODArray & out) const; - -private: - template - using MapType = HashMapWithSavedHash; - template - using ContainerType = Value[]; - template - using ContainerPtrType = 
std::unique_ptr>; - - struct CellMetadata final - { - using time_point_t = std::chrono::system_clock::time_point; - using time_point_rep_t = time_point_t::rep; - using time_point_urep_t = std::make_unsigned_t; - - static constexpr UInt64 EXPIRES_AT_MASK = std::numeric_limits::max(); - static constexpr UInt64 IS_DEFAULT_MASK = ~EXPIRES_AT_MASK; - - StringRef key; - decltype(StringRefHash{}(key)) hash; - /// Stores both expiration time and `is_default` flag in the most significant bit - time_point_urep_t data; - - /// Sets expiration time, resets `is_default` flag to false - time_point_t expiresAt() const { return ext::safe_bit_cast(data & EXPIRES_AT_MASK); } - void setExpiresAt(const time_point_t & t) { data = ext::safe_bit_cast(t); } - - bool isDefault() const { return (data & IS_DEFAULT_MASK) == IS_DEFAULT_MASK; } - void setDefault() { data |= IS_DEFAULT_MASK; } - }; - - struct Attribute final - { - AttributeUnderlyingType type; - std::tuple - null_values; - std::tuple< - ContainerPtrType, - ContainerPtrType, - ContainerPtrType, - ContainerPtrType, - ContainerPtrType, - ContainerPtrType, - ContainerPtrType, - ContainerPtrType, - ContainerPtrType, - ContainerPtrType, - ContainerPtrType, - ContainerPtrType> - arrays; - }; - - void createAttributes(); - - Attribute createAttributeWithType(const AttributeUnderlyingType type, const Field & null_value); - - template - void getItemsNumber( - Attribute & attribute, - const Columns & key_columns, - PaddedPODArray & out, - DefaultGetter && get_default) const - { - if (false) {} -#define DISPATCH(TYPE) \ - else if (attribute.type == AttributeUnderlyingType::TYPE) \ - getItemsNumberImpl(attribute, key_columns, out, std::forward(get_default)); - DISPATCH(UInt8) - DISPATCH(UInt16) - DISPATCH(UInt32) - DISPATCH(UInt64) - DISPATCH(UInt128) - DISPATCH(Int8) - DISPATCH(Int16) - DISPATCH(Int32) - DISPATCH(Int64) - DISPATCH(Float32) - DISPATCH(Float64) -#undef DISPATCH - else throw Exception("Unexpected type of attribute: " + 
toString(attribute.type), ErrorCodes::LOGICAL_ERROR); - }; - - template - void getItemsNumberImpl( - Attribute & attribute, - const Columns & key_columns, - PaddedPODArray & out, - DefaultGetter && get_default) const - { - /// Mapping: -> { all indices `i` of `key_columns` such that `key_columns[i]` = } - MapType> outdated_keys; - auto & attribute_array = std::get>(attribute.arrays); - - const auto rows_num = key_columns.front()->size(); - const auto keys_size = dict_struct.key->size(); - StringRefs keys(keys_size); - Arena temporary_keys_pool; - PODArray keys_array(rows_num); - - { - const auto now = std::chrono::system_clock::now(); - /// fetch up-to-date values, decide which ones require update - for (const auto row : ext::range(0, rows_num)) - { - const StringRef key = placeKeysInPool(row, key_columns, keys, *dict_struct.key, temporary_keys_pool); - keys_array[row] = key; - const auto find_result = findCellIdx(key, now); - - /** cell should be updated if either: - * 1. keys (or hash) do not match, - * 2. cell has expired, - * 3. explicit defaults were specified and cell was set default. */ - - if (!find_result.valid) - { - outdated_keys[key].push_back(row); - } - else - { - const auto & cell_idx = find_result.cell_idx; - const auto & cell = cells[cell_idx]; - out[row] = cell.isDefault() ? 
get_default(row) : static_cast(attribute_array[cell_idx]); - } - } - } - query_count.fetch_add(rows_num, std::memory_order_relaxed); - hit_count.fetch_add(rows_num - outdated_keys.size(), std::memory_order_release); - - if (outdated_keys.empty()) - return; - - std::vector required_rows(outdated_keys.size()); - std::transform(std::begin(outdated_keys), std::end(outdated_keys), std::begin(required_rows), [](auto & pair) { - return pair.getMapped().front(); - }); - - /// request new values - update( - key_columns, - keys_array, - required_rows, - [&](const StringRef key, const size_t cell_idx) { - for (const auto row : outdated_keys[key]) - out[row] = static_cast(attribute_array[cell_idx]); - }, - [&](const StringRef key, const size_t) { - for (const auto row : outdated_keys[key]) - out[row] = get_default(row); - }); - }; - - template - void getItemsString( - Attribute & attribute, - const Columns & key_columns, - ColumnString * out, - DefaultGetter && get_default) const - { - const auto rows_num = key_columns.front()->size(); - /// save on some allocations - out->getOffsets().reserve(rows_num); - - const auto keys_size = dict_struct.key->size(); - StringRefs keys(keys_size); - Arena temporary_keys_pool; - - auto & attribute_array = std::get>(attribute.arrays); - - auto found_outdated_values = false; - - /// perform optimistic version, fallback to pessimistic if failed - { - const auto now = std::chrono::system_clock::now(); - /// fetch up-to-date values, discard on fail - for (const auto row : ext::range(0, rows_num)) - { - const StringRef key = placeKeysInPool(row, key_columns, keys, *dict_struct.key, temporary_keys_pool); - SCOPE_EXIT(temporary_keys_pool.rollback(key.size)); - const auto find_result = findCellIdx(key, now); - - if (!find_result.valid) - { - found_outdated_values = true; - break; - } - else - { - const auto & cell_idx = find_result.cell_idx; - const auto & cell = cells[cell_idx]; - const auto string_ref = cell.isDefault() ? 
get_default(row) : attribute_array[cell_idx]; - out->insertData(string_ref.data, string_ref.size); - } - } - } - - /// optimistic code completed successfully - if (!found_outdated_values) - { - query_count.fetch_add(rows_num, std::memory_order_relaxed); - hit_count.fetch_add(rows_num, std::memory_order_release); - return; - } - - /// now onto the pessimistic one, discard possible partial results from the optimistic path - out->getChars().resize_assume_reserved(0); - out->getOffsets().resize_assume_reserved(0); - - /// Mapping: -> { all indices `i` of `key_columns` such that `key_columns[i]` = } - MapType> outdated_keys; - /// we are going to store every string separately - MapType map; - PODArray keys_array(rows_num); - - size_t total_length = 0; - { - const auto now = std::chrono::system_clock::now(); - for (const auto row : ext::range(0, rows_num)) - { - const StringRef key = placeKeysInPool(row, key_columns, keys, *dict_struct.key, temporary_keys_pool); - keys_array[row] = key; - const auto find_result = findCellIdx(key, now); - - if (!find_result.valid) - { - outdated_keys[key].push_back(row); - } - else - { - const auto & cell_idx = find_result.cell_idx; - const auto & cell = cells[cell_idx]; - const auto string_ref = cell.isDefault() ? 
get_default(row) : attribute_array[cell_idx]; - - if (!cell.isDefault()) - map[key] = copyIntoArena(string_ref, temporary_keys_pool); - - total_length += string_ref.size + 1; - } - } - } - - query_count.fetch_add(rows_num, std::memory_order_relaxed); - hit_count.fetch_add(rows_num - outdated_keys.size(), std::memory_order_release); - - /// request new values - if (!outdated_keys.empty()) - { - std::vector required_rows(outdated_keys.size()); - std::transform( - std::begin(outdated_keys), - std::end(outdated_keys), - std::begin(required_rows), - [](auto & pair) { return pair.getMapped().front(); }); - - update( - key_columns, - keys_array, - required_rows, - [&](const StringRef key, const size_t cell_idx) { - const StringRef attribute_value = attribute_array[cell_idx]; - - /// We must copy key and value to own memory, because it may be replaced with another - /// in next iterations of inner loop of update. - const StringRef copied_key = copyIntoArena(key, temporary_keys_pool); - const StringRef copied_value = copyIntoArena(attribute_value, temporary_keys_pool); - - map[copied_key] = copied_value; - total_length += (attribute_value.size + 1) * outdated_keys[key].size(); - }, - [&](const StringRef key, const size_t) { - for (const auto row : outdated_keys[key]) - total_length += get_default(row).size + 1; - }); - } - - out->getChars().reserve(total_length); - - for (const auto row : ext::range(0, ext::size(keys_array))) - { - const StringRef key = keys_array[row]; - auto * const it = map.find(key); - const auto string_ref = it != std::end(map) ? 
it->getMapped() : get_default(row); - out->insertData(string_ref.data, string_ref.size); - } - }; - - template - void update( - const Columns & in_key_columns, - const PODArray & in_keys, - const std::vector & in_requested_rows, - PresentKeyHandler && on_cell_updated, - AbsentKeyHandler && on_key_not_found) const - { - MapType remaining_keys{in_requested_rows.size()}; - for (const auto row : in_requested_rows) - remaining_keys.insert({in_keys[row], false}); - - std::uniform_int_distribution distribution(dict_lifetime.min_sec, dict_lifetime.max_sec); - - { - Stopwatch watch; - auto stream = source_ptr->loadKeys(in_key_columns, in_requested_rows); - stream->readPrefix(); - - const auto keys_size = dict_struct.key->size(); - StringRefs keys(keys_size); - - const auto attributes_size = attributes.size(); - const auto now = std::chrono::system_clock::now(); - - while (const auto block = stream->read()) - { - /// cache column pointers - const auto key_columns = ext::map(ext::range(0, keys_size), [&](const size_t attribute_idx) { - return block.safeGetByPosition(attribute_idx).column; - }); - - const auto attribute_columns - = ext::map(ext::range(0, attributes_size), [&](const size_t attribute_idx) { - return block.safeGetByPosition(keys_size + attribute_idx).column; - }); - - const auto rows_num = block.rows(); - - for (const auto row : ext::range(0, rows_num)) - { - auto key = allocKey(row, key_columns, keys); - const auto hash = StringRefHash{}(key); - const auto find_result = findCellIdx(key, now, hash); - const auto & cell_idx = find_result.cell_idx; - auto & cell = cells[cell_idx]; - - for (const auto attribute_idx : ext::range(0, attributes.size())) - { - const auto & attribute_column = *attribute_columns[attribute_idx]; - auto & attribute = attributes[attribute_idx]; - - setAttributeValue(attribute, cell_idx, attribute_column[row]); - } - - /// if cell id is zero and zero does not map to this cell, then the cell is unused - if (cell.key == StringRef{} && cell_idx 
!= zero_cell_idx) - element_count.fetch_add(1, std::memory_order_relaxed); - - /// handle memory allocated for old key - if (key == cell.key) - { - freeKey(key); - key = cell.key; - } - else - { - /// new key is different from the old one - if (cell.key.data) - freeKey(cell.key); - - cell.key = key; - } - - cell.hash = hash; - - if (dict_lifetime.min_sec != 0 && dict_lifetime.max_sec != 0) - cell.setExpiresAt( - std::chrono::system_clock::now() + std::chrono::seconds{distribution(rnd_engine)}); - else - cell.setExpiresAt(std::chrono::time_point::max()); - - /// inform caller - on_cell_updated(key, cell_idx); - /// mark corresponding id as found - remaining_keys[key] = true; - } - } - - stream->readSuffix(); - } - - const auto now = std::chrono::system_clock::now(); - - /// Check which ids have not been found and require setting null_value - for (const auto & key_found_pair : remaining_keys) - { - if (key_found_pair.getMapped()) - { - continue; - } - - auto key = key_found_pair.getKey(); - const auto hash = StringRefHash{}(key); - const auto find_result = findCellIdx(key, now, hash); - const auto & cell_idx = find_result.cell_idx; - auto & cell = cells[cell_idx]; - - /// Set null_value for each attribute - for (auto & attribute : attributes) - setDefaultAttributeValue(attribute, cell_idx); - - /// Check if cell had not been occupied before and increment element counter if it hadn't - if (cell.key == StringRef{} && cell_idx != zero_cell_idx) - element_count.fetch_add(1, std::memory_order_relaxed); - - if (key == cell.key) - key = cell.key; - else - { - if (cell.key.data) - freeKey(cell.key); - - /// copy key from temporary pool - key = copyKey(key); - cell.key = key; - } - - cell.hash = hash; - - if (dict_lifetime.min_sec != 0 && dict_lifetime.max_sec != 0) - cell.setExpiresAt(std::chrono::system_clock::now() + std::chrono::seconds{distribution(rnd_engine)}); - else - cell.setExpiresAt(std::chrono::time_point::max()); - - cell.setDefault(); - - /// inform caller that 
the cell has not been found - on_key_not_found(key, cell_idx); - } - }; - - UInt64 getCellIdx(const StringRef key) const; - - void setDefaultAttributeValue(Attribute & attribute, const size_t idx) const; - - void setAttributeValue(Attribute & attribute, const size_t idx, const Field & value) const; - - Attribute & getAttribute(const std::string & attribute_name) const; - - StringRef allocKey(const size_t row, const Columns & key_columns, StringRefs & keys) const; - - void freeKey(const StringRef key) const; - - template - static StringRef placeKeysInPool( - const size_t row, - const Columns & key_columns, - StringRefs & keys, - const std::vector & key_attributes, - Arena & pool); - - StringRef placeKeysInFixedSizePool(const size_t row, const Columns & key_columns) const; - - static StringRef copyIntoArena(StringRef src, Arena & arena); - StringRef copyKey(const StringRef key) const; - - struct FindResult - { - const size_t cell_idx; - const bool valid; - const bool outdated; - }; - - FindResult findCellIdx(const StringRef & key, const CellMetadata::time_point_t now, const size_t hash) const; - FindResult findCellIdx(const StringRef & key, const CellMetadata::time_point_t now) const - { - const auto hash = StringRefHash{}(key); - return findCellIdx(key, now, hash); - }; - - bool isEmptyCell(const UInt64 idx) const; - - const std::string name; - const DictionaryStructure dict_struct; - const DictionarySourcePtr source_ptr; - const DictionaryLifetime dict_lifetime; - const std::string key_description{dict_struct.getKeyDescription()}; - - mutable std::shared_mutex rw_lock; - - /// Actual size will be increased to match power of 2 - const size_t size; - - /// all bits to 1 mask (size - 1) (0b1000 - 1 = 0b111) - const size_t size_overlap_mask; - - /// Max tries to find cell, overlaped with mask: if size = 16 and start_cell=10: will try cells: 10,11,12,13,14,15,0,1,2,3 - static constexpr size_t max_collision_length = 10; - - const UInt64 
zero_cell_idx{getCellIdx(StringRef{})}; - std::map attribute_index_by_name; - mutable std::vector attributes; - mutable std::vector cells{size}; - const bool key_size_is_fixed{dict_struct.isKeySizeFixed()}; - size_t key_size{key_size_is_fixed ? dict_struct.getKeySize() : 0}; - std::unique_ptr keys_pool - = key_size_is_fixed ? nullptr : std::make_unique(); - std::unique_ptr fixed_size_keys_pool - = key_size_is_fixed ? std::make_unique(key_size) : nullptr; - std::unique_ptr string_arena; - - mutable pcg64 rnd_engine; - - mutable size_t bytes_allocated = 0; - mutable std::atomic element_count{0}; - mutable std::atomic hit_count{0}; - mutable std::atomic query_count{0}; - - const std::chrono::time_point creation_time = std::chrono::system_clock::now(); -}; -} // namespace DB diff --git a/dbms/src/Dictionaries/ComplexKeyCacheDictionary_createAttributeWithType.cpp b/dbms/src/Dictionaries/ComplexKeyCacheDictionary_createAttributeWithType.cpp deleted file mode 100644 index 129f6e0eacf..00000000000 --- a/dbms/src/Dictionaries/ComplexKeyCacheDictionary_createAttributeWithType.cpp +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include - -namespace DB -{ - -ComplexKeyCacheDictionary::Attribute ComplexKeyCacheDictionary::createAttributeWithType( - const AttributeUnderlyingType type, - const Field & null_value) -{ - Attribute attr{type, {}, {}}; - - switch (type) - { - case AttributeUnderlyingType::UInt8: - std::get(attr.null_values) = null_value.get(); - std::get>(attr.arrays) = std::make_unique>(size); - bytes_allocated += size * sizeof(UInt8); - break; - case AttributeUnderlyingType::UInt16: - std::get(attr.null_values) = null_value.get(); - std::get>(attr.arrays) = std::make_unique>(size); - bytes_allocated += size * sizeof(UInt16); - break; - case AttributeUnderlyingType::UInt32: - std::get(attr.null_values) = null_value.get(); - std::get>(attr.arrays) = std::make_unique>(size); - bytes_allocated += size * sizeof(UInt32); - break; - case AttributeUnderlyingType::UInt64: - std::get(attr.null_values) = null_value.get(); - std::get>(attr.arrays) = std::make_unique>(size); - bytes_allocated += size * sizeof(UInt64); - break; - case AttributeUnderlyingType::UInt128: - std::get(attr.null_values) = null_value.get(); - std::get>(attr.arrays) = std::make_unique>(size); - bytes_allocated += size * sizeof(UInt128); - break; - case AttributeUnderlyingType::Int8: - std::get(attr.null_values) = null_value.get(); - std::get>(attr.arrays) = std::make_unique>(size); - bytes_allocated += size * sizeof(Int8); - break; - case AttributeUnderlyingType::Int16: - std::get(attr.null_values) = null_value.get(); - std::get>(attr.arrays) = std::make_unique>(size); - bytes_allocated += size * sizeof(Int16); - break; - case AttributeUnderlyingType::Int32: - std::get(attr.null_values) = null_value.get(); - std::get>(attr.arrays) = std::make_unique>(size); - bytes_allocated += size * sizeof(Int32); - break; - case AttributeUnderlyingType::Int64: - std::get(attr.null_values) = null_value.get(); - std::get>(attr.arrays) = std::make_unique>(size); - bytes_allocated += size * sizeof(Int64); - break; - case 
AttributeUnderlyingType::Float32: - std::get(attr.null_values) = null_value.get(); - std::get>(attr.arrays) = std::make_unique>(size); - bytes_allocated += size * sizeof(Float32); - break; - case AttributeUnderlyingType::Float64: - std::get(attr.null_values) = null_value.get(); - std::get>(attr.arrays) = std::make_unique>(size); - bytes_allocated += size * sizeof(Float64); - break; - case AttributeUnderlyingType::String: - std::get(attr.null_values) = null_value.get(); - std::get>(attr.arrays) = std::make_unique>(size); - bytes_allocated += size * sizeof(StringRef); - if (!string_arena) - string_arena = std::make_unique(); - break; - } - - return attr; -} - -} // namespace DB diff --git a/dbms/src/Dictionaries/ComplexKeyCacheDictionary_generate1.cpp b/dbms/src/Dictionaries/ComplexKeyCacheDictionary_generate1.cpp deleted file mode 100644 index 2b308c0b180..00000000000 --- a/dbms/src/Dictionaries/ComplexKeyCacheDictionary_generate1.cpp +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "ComplexKeyCacheDictionary.h" - -namespace DB -{ -namespace ErrorCodes -{ -extern const int TYPE_MISMATCH; -} - -#define DECLARE(TYPE) \ - void ComplexKeyCacheDictionary::get##TYPE( \ - const std::string & attribute_name, \ - const Columns & key_columns, \ - const DataTypes & key_types, \ - PaddedPODArray & out) const \ - { \ - dict_struct.validateKeyTypes(key_types); \ - \ - auto & attribute = getAttribute(attribute_name); \ - if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::TYPE)) \ - throw Exception{ \ - name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type), \ - ErrorCodes::TYPE_MISMATCH}; \ - \ - const auto null_value = std::get(attribute.null_values); \ - \ - getItemsNumber(attribute, key_columns, out, [&](const size_t) { return null_value; }); \ - } -DECLARE(UInt8) -DECLARE(UInt16) -DECLARE(UInt32) -DECLARE(UInt64) -DECLARE(UInt128) -DECLARE(Int8) -DECLARE(Int16) -DECLARE(Int32) -DECLARE(Int64) -DECLARE(Float32) -DECLARE(Float64) -#undef DECLARE -} // namespace DB diff --git a/dbms/src/Dictionaries/ComplexKeyCacheDictionary_generate2.cpp b/dbms/src/Dictionaries/ComplexKeyCacheDictionary_generate2.cpp deleted file mode 100644 index 85733a92b0e..00000000000 --- a/dbms/src/Dictionaries/ComplexKeyCacheDictionary_generate2.cpp +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "ComplexKeyCacheDictionary.h" - -namespace DB -{ -namespace ErrorCodes -{ -extern const int TYPE_MISMATCH; -} - -#define DECLARE(TYPE) \ - void ComplexKeyCacheDictionary::get##TYPE( \ - const std::string & attribute_name, \ - const Columns & key_columns, \ - const DataTypes & key_types, \ - const PaddedPODArray & def, \ - PaddedPODArray & out) const \ - { \ - dict_struct.validateKeyTypes(key_types); \ - \ - auto & attribute = getAttribute(attribute_name); \ - if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::TYPE)) \ - throw Exception{ \ - name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type), \ - ErrorCodes::TYPE_MISMATCH}; \ - \ - getItemsNumber(attribute, key_columns, out, [&](const size_t row) { return def[row]; }); \ - } -DECLARE(UInt8) -DECLARE(UInt16) -DECLARE(UInt32) -DECLARE(UInt64) -DECLARE(UInt128) -DECLARE(Int8) -DECLARE(Int16) -DECLARE(Int32) -DECLARE(Int64) -DECLARE(Float32) -DECLARE(Float64) -#undef DECLARE -} // namespace DB diff --git a/dbms/src/Dictionaries/ComplexKeyCacheDictionary_generate3.cpp b/dbms/src/Dictionaries/ComplexKeyCacheDictionary_generate3.cpp deleted file mode 100644 index 5097ef4b1a8..00000000000 --- a/dbms/src/Dictionaries/ComplexKeyCacheDictionary_generate3.cpp +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "ComplexKeyCacheDictionary.h" - -namespace DB -{ -namespace ErrorCodes -{ -extern const int TYPE_MISMATCH; -} - -#define DECLARE(TYPE) \ - void ComplexKeyCacheDictionary::get##TYPE( \ - const std::string & attribute_name, \ - const Columns & key_columns, \ - const DataTypes & key_types, \ - const TYPE def, \ - PaddedPODArray & out) const \ - { \ - dict_struct.validateKeyTypes(key_types); \ - \ - auto & attribute = getAttribute(attribute_name); \ - if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::TYPE)) \ - throw Exception{ \ - name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type), \ - ErrorCodes::TYPE_MISMATCH}; \ - \ - getItemsNumber(attribute, key_columns, out, [&](const size_t) { return def; }); \ - } -DECLARE(UInt8) -DECLARE(UInt16) -DECLARE(UInt32) -DECLARE(UInt64) -DECLARE(UInt128) -DECLARE(Int8) -DECLARE(Int16) -DECLARE(Int32) -DECLARE(Int64) -DECLARE(Float32) -DECLARE(Float64) -#undef DECLARE -} // namespace DB diff --git a/dbms/src/Dictionaries/ComplexKeyCacheDictionary_setAttributeValue.cpp b/dbms/src/Dictionaries/ComplexKeyCacheDictionary_setAttributeValue.cpp deleted file mode 100644 index 9b111e17727..00000000000 --- a/dbms/src/Dictionaries/ComplexKeyCacheDictionary_setAttributeValue.cpp +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include - -namespace DB -{ - -void ComplexKeyCacheDictionary::setAttributeValue(Attribute & attribute, const size_t idx, const Field & value) const -{ - switch (attribute.type) - { - case AttributeUnderlyingType::UInt8: - std::get>(attribute.arrays)[idx] = value.get(); - break; - case AttributeUnderlyingType::UInt16: - std::get>(attribute.arrays)[idx] = value.get(); - break; - case AttributeUnderlyingType::UInt32: - std::get>(attribute.arrays)[idx] = value.get(); - break; - case AttributeUnderlyingType::UInt64: - std::get>(attribute.arrays)[idx] = value.get(); - break; - case AttributeUnderlyingType::UInt128: - std::get>(attribute.arrays)[idx] = value.get(); - break; - case AttributeUnderlyingType::Int8: - std::get>(attribute.arrays)[idx] = value.get(); - break; - case AttributeUnderlyingType::Int16: - std::get>(attribute.arrays)[idx] = value.get(); - break; - case AttributeUnderlyingType::Int32: - std::get>(attribute.arrays)[idx] = value.get(); - break; - case AttributeUnderlyingType::Int64: - std::get>(attribute.arrays)[idx] = value.get(); - break; - case AttributeUnderlyingType::Float32: - std::get>(attribute.arrays)[idx] = value.get(); - break; - case AttributeUnderlyingType::Float64: - std::get>(attribute.arrays)[idx] = value.get(); - break; - case AttributeUnderlyingType::String: - { - const auto & string = value.get(); - auto & string_ref = std::get>(attribute.arrays)[idx]; - const auto & null_value_ref = std::get(attribute.null_values); - - /// free memory unless it points to a null_value - if (string_ref.data && string_ref.data != null_value_ref.data()) - string_arena->free(const_cast(string_ref.data), string_ref.size); - - const auto size = string.size(); - if (size != 0) - { - auto string_ptr = string_arena->alloc(size + 1); - std::copy(string.data(), string.data() + size + 1, string_ptr); - string_ref = StringRef{string_ptr, size}; - } - else - string_ref = {}; - - break; - } - } -} - -} // namespace DB diff --git 
a/dbms/src/Dictionaries/ComplexKeyCacheDictionary_setDefaultAttributeValue.cpp b/dbms/src/Dictionaries/ComplexKeyCacheDictionary_setDefaultAttributeValue.cpp deleted file mode 100644 index 2b97546164f..00000000000 --- a/dbms/src/Dictionaries/ComplexKeyCacheDictionary_setDefaultAttributeValue.cpp +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include - -namespace DB -{ - -void ComplexKeyCacheDictionary::setDefaultAttributeValue(Attribute & attribute, const size_t idx) const -{ - switch (attribute.type) - { - case AttributeUnderlyingType::UInt8: - std::get>(attribute.arrays)[idx] = std::get(attribute.null_values); - break; - case AttributeUnderlyingType::UInt16: - std::get>(attribute.arrays)[idx] = std::get(attribute.null_values); - break; - case AttributeUnderlyingType::UInt32: - std::get>(attribute.arrays)[idx] = std::get(attribute.null_values); - break; - case AttributeUnderlyingType::UInt64: - std::get>(attribute.arrays)[idx] = std::get(attribute.null_values); - break; - case AttributeUnderlyingType::UInt128: - std::get>(attribute.arrays)[idx] = std::get(attribute.null_values); - break; - case AttributeUnderlyingType::Int8: - std::get>(attribute.arrays)[idx] = std::get(attribute.null_values); - break; - case AttributeUnderlyingType::Int16: - std::get>(attribute.arrays)[idx] = std::get(attribute.null_values); - break; - case AttributeUnderlyingType::Int32: - 
std::get>(attribute.arrays)[idx] = std::get(attribute.null_values); - break; - case AttributeUnderlyingType::Int64: - std::get>(attribute.arrays)[idx] = std::get(attribute.null_values); - break; - case AttributeUnderlyingType::Float32: - std::get>(attribute.arrays)[idx] = std::get(attribute.null_values); - break; - case AttributeUnderlyingType::Float64: - std::get>(attribute.arrays)[idx] = std::get(attribute.null_values); - break; - case AttributeUnderlyingType::String: - { - const auto & null_value_ref = std::get(attribute.null_values); - auto & string_ref = std::get>(attribute.arrays)[idx]; - - if (string_ref.data != null_value_ref.data()) - { - if (string_ref.data) - string_arena->free(const_cast(string_ref.data), string_ref.size); - - string_ref = StringRef{null_value_ref}; - } - - break; - } - } -} - -} // namespace DB diff --git a/dbms/src/Flash/Coprocessor/TablesRegionsInfo.cpp b/dbms/src/Flash/Coprocessor/TablesRegionsInfo.cpp index be11424f14b..3c0dfe9041d 100644 --- a/dbms/src/Flash/Coprocessor/TablesRegionsInfo.cpp +++ b/dbms/src/Flash/Coprocessor/TablesRegionsInfo.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include namespace DB diff --git a/dbms/src/Flash/Mpp/tests/gtest_mpp_task_manager.cpp b/dbms/src/Flash/Mpp/tests/gtest_mpp_task_manager.cpp index b1a915395af..6e96039898c 100644 --- a/dbms/src/Flash/Mpp/tests/gtest_mpp_task_manager.cpp +++ b/dbms/src/Flash/Mpp/tests/gtest_mpp_task_manager.cpp @@ -15,13 +15,12 @@ #include #include #include +#include #include #include #include #include -#include "Server/RaftConfigParser.h" - namespace DB { namespace tests diff --git a/dbms/src/Flash/Planner/Plans/PhysicalJoinV2.cpp b/dbms/src/Flash/Planner/Plans/PhysicalJoinV2.cpp index 6260a55f8f5..fbe5a1b305a 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalJoinV2.cpp +++ b/dbms/src/Flash/Planner/Plans/PhysicalJoinV2.cpp @@ -209,7 +209,7 @@ bool PhysicalJoinV2::isSupported(const tipb::Join & join) switch (tiflash_join.kind) { case Inner: - //case 
LeftOuter: + case LeftOuter: //case Semi: //case Anti: //case RightOuter: diff --git a/dbms/src/Flash/Planner/tests/gtest_physical_plan.cpp b/dbms/src/Flash/Planner/tests/gtest_physical_plan.cpp index add5518954b..52bed824c46 100644 --- a/dbms/src/Flash/Planner/tests/gtest_physical_plan.cpp +++ b/dbms/src/Flash/Planner/tests/gtest_physical_plan.cpp @@ -13,7 +13,6 @@ // limitations under the License. #include -#include #include #include #include diff --git a/dbms/src/Flash/tests/gtest_aggregation_executor.cpp b/dbms/src/Flash/tests/gtest_aggregation_executor.cpp index 1bf01aa956b..19d02b4a8c5 100644 --- a/dbms/src/Flash/tests/gtest_aggregation_executor.cpp +++ b/dbms/src/Flash/tests/gtest_aggregation_executor.cpp @@ -27,6 +27,7 @@ extern const char force_agg_on_partial_block[]; extern const char force_agg_prefetch[]; extern const char force_agg_two_level_hash_table_before_merge[]; extern const char force_magic_hash[]; +extern const char disable_agg_batch_get_key_holder[]; } // namespace FailPoints namespace tests { @@ -196,8 +197,9 @@ class AggExecutorTestRunner : public ExecutorTest {"key_64_3", TiDB::TP::TypeLongLong, false}, {"key_string_1", TiDB::TP::TypeString, false}, {"key_string_2", TiDB::TP::TypeString, false}, + {"key_nullable_string", TiDB::TP::TypeString, true}, {"key_decimal256", TiDB::TP::TypeString, false}, - {"key_64_nullable", TiDB::TP::TypeLongLong, true}, + {"key_nullable_int64", TiDB::TP::TypeLongLong, true}, {"value", TiDB::TP::TypeLong, false}}; ColumnsWithTypeAndName table_column_data; for (const auto & column_info : mockColumnInfosToTiDBColumnInfos(table_column_infos)) @@ -874,8 +876,14 @@ try // 0: use one level // 1: use two level std::vector two_level_thresholds{0, 1}; - std::vector collators{TiDB::ITiDBCollator::UTF8MB4_BIN, TiDB::ITiDBCollator::UTF8MB4_GENERAL_CI}; + std::vector collators{ + TiDB::ITiDBCollator::UTF8MB4_BIN, + TiDB::ITiDBCollator::BINARY, + TiDB::ITiDBCollator::UTF8_UNICODE_CI, + TiDB::ITiDBCollator::UTF8MB4_GENERAL_CI}; 
std::vector> group_by_keys{ + /// fast path with one int and one string + {"key_64", "key_nullable_int64", "key_string_1"}, /// fast path with one int and one string {"key_64", "key_string_1"}, /// fast path with two string @@ -884,6 +892,10 @@ try {"key_string_1"}, /// keys need to be shuffled {"key_8", "key_16", "key_32", "key_64"}, + /// test nullable key_serialized(batch-wise) + {"key_nullable_string", "key_nullable_int64", "key_32"}, + /// test nullable key_string(batch-wise) + {"key_nullable_string"}, }; for (auto collator_id : collators) { @@ -906,7 +918,10 @@ try context.context->setSetting( "max_block_size", Field(static_cast(tbl_agg_table_with_special_key_unique_rows * 2))); + // Use non batch way to get reference. + FailPointHelper::enableFailPoint(FailPoints::disable_agg_batch_get_key_holder); auto reference = executeStreams(request); + FailPointHelper::disableFailPoint(FailPoints::disable_agg_batch_get_key_holder); if (current_collator->isCI()) { /// for ci collation, need to sort and compare the result manually @@ -1283,9 +1298,9 @@ try // keys256 {"key_64", "key_64_1", "key_64_2", "key_64_3"}, // nullable_keys128 - {"key_16", "key_64_nullable"}, + {"key_16", "key_nullable_int64"}, // nullable_keys256 - {"key_64", "key_64_nullable"}, + {"key_64", "key_nullable_int64"}, }; for (const auto & keys : group_by_keys) { diff --git a/dbms/src/Flash/tests/gtest_spill_sort.cpp b/dbms/src/Flash/tests/gtest_spill_sort.cpp index 084dc68e71c..d9a61043562 100644 --- a/dbms/src/Flash/tests/gtest_spill_sort.cpp +++ b/dbms/src/Flash/tests/gtest_spill_sort.cpp @@ -12,10 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include #include #include #include #include +#include namespace DB { @@ -166,5 +168,58 @@ try } CATCH +TEST_F(SpillSortTestRunner, SpillAfterFilter) +try +{ + DB::MockColumnInfoVec column_infos{ + {"a", TiDB::TP::TypeTiny}, + {"b", TiDB::TP::TypeTiny}, + {"c", TiDB::TP::TypeTiny}, + {"d", TiDB::TP::TypeTiny}, + {"e", TiDB::TP::TypeTiny}}; + ColumnsWithTypeAndName column_data; + size_t table_rows = 102400; + UInt64 max_block_size = 64; + size_t total_data_size = 0; + size_t limit_size = table_rows / 10 * 9; + for (const auto & column_info : mockColumnInfosToTiDBColumnInfos(column_infos)) + { + ColumnGeneratorOpts opts{ + table_rows, + getDataTypeByColumnInfoForComputingLayer(column_info)->getName(), + RANDOM, + column_info.name}; + column_data.push_back(ColumnGenerator::instance().generate(opts)); + total_data_size += column_data.back().column->byteSize(); + } + context.addMockTable("spill_sort_test", "simple_table", column_infos, column_data, 8); + + MockOrderByItemVec order_by_items{ + std::make_pair("a", true), + std::make_pair("b", true), + std::make_pair("c", true), + std::make_pair("d", true), + std::make_pair("e", true)}; + + auto request = context.scan("spill_sort_test", "simple_table") + .filter(gt(col("d"), lit(toField(static_cast(-128))))) + .topN(order_by_items, limit_size) + .project({gt(col("d"), lit(toField(static_cast(-128))))}) + .build(context); + context.context->setSetting("max_block_size", Field(static_cast(max_block_size))); + + /// disable spill + context.context->setSetting("max_bytes_before_external_sort", Field(static_cast(0))); + auto ref_columns = executeStreams(request, 1); + + // The implementation of topN in the pipeline model is LocalSort, and the result of using multiple threads is unstable. Therefore, a single thread is used here instead. 
+ enablePipeline(true); + context.context->setSetting("max_bytes_before_external_sort", Field(static_cast(total_data_size / 10))); + ASSERT_COLUMNS_EQ_R(ref_columns, executeStreams(request, 1)); + context.context->setSetting("max_cached_data_bytes_in_spiller", Field(static_cast(total_data_size / 100))); + ASSERT_COLUMNS_EQ_R(ref_columns, executeStreams(request, 1)); +} +CATCH + } // namespace tests } // namespace DB diff --git a/dbms/src/Functions/CMakeLists.txt b/dbms/src/Functions/CMakeLists.txt index 75ee86d0ecd..ef6d095a328 100644 --- a/dbms/src/Functions/CMakeLists.txt +++ b/dbms/src/Functions/CMakeLists.txt @@ -38,7 +38,3 @@ if (CMAKE_BUILD_TYPE_UC STREQUAL "RELEASE" OR CMAKE_BUILD_TYPE_UC STREQUAL "RELW # Won't generate debug info for files with heavy template instantiation to achieve faster linking and lower size. target_compile_options(tiflash_functions PRIVATE "-g0") endif () - -if (USE_VECTORCLASS) - target_include_directories (tiflash_functions BEFORE PUBLIC ${VECTORCLASS_INCLUDE_DIR}) -endif () diff --git a/dbms/src/Functions/FunctionsMath.h b/dbms/src/Functions/FunctionsMath.h index 0670a5909b3..3ae4b94d3a0 100644 --- a/dbms/src/Functions/FunctionsMath.h +++ b/dbms/src/Functions/FunctionsMath.h @@ -27,28 +27,6 @@ #include #include -/** More efficient implementations of mathematical functions are possible when using a separate library. - * Disabled due to licence compatibility limitations. 
- * To enable: download http://www.agner.org/optimize/vectorclass.zip and unpack to contrib/vectorclass - * Then rebuild with -DENABLE_VECTORCLASS=1 - */ - -#if USE_VECTORCLASS -#if __clang__ -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wshift-negative-value" -#endif - -#include -#include -#include - -#if __clang__ -#pragma clang diagnostic pop -#endif -#endif - - namespace DB { namespace ErrorCodes @@ -231,30 +209,9 @@ struct UnaryFunctionPlain } }; -#if USE_VECTORCLASS - -template -struct UnaryFunctionVectorized -{ - static constexpr auto name = Name::name; - static constexpr auto rows_per_iteration = 2; - - template - static void execute(const T * src, Float64 * dst) - { - const auto result = Function(Vec2d(src[0], src[1])); - result.store(dst); - } -}; - -#else - #define UnaryFunctionVectorized UnaryFunctionPlain #define UnaryFunctionNullableVectorized UnaryFunctionNullablePlain -#endif - - template class FunctionMathBinaryFloat64 : public IFunction { @@ -637,29 +594,8 @@ struct BinaryFunctionPlain } }; -#if USE_VECTORCLASS - -template -struct BinaryFunctionVectorized -{ - static constexpr auto name = Name::name; - static constexpr auto rows_per_iteration = 2; - - template - static void execute(const T1 * src_left, const T2 * src_right, Float64 * dst) - { - const auto result = Function(Vec2d(src_left[0], src_left[1]), Vec2d(src_right[0], src_right[1])); - result.store(dst); - } -}; - -#else - #define BinaryFunctionVectorized BinaryFunctionPlain -#endif - - struct EImpl { static constexpr auto name = "e"; @@ -830,14 +766,7 @@ using FunctionExp10 = FunctionMathUnaryFloat64>; using FunctionSqrt = FunctionMathUnaryFloat64Nullable>; -using FunctionCbrt = FunctionMathUnaryFloat64::pow -#else - cbrt -#endif - >>; +using FunctionCbrt = FunctionMathUnaryFloat64>; using FunctionSin = FunctionMathUnaryFloat64>; using FunctionCos = FunctionMathUnaryFloat64>; diff --git a/dbms/src/Interpreters/Aggregator.cpp b/dbms/src/Interpreters/Aggregator.cpp index 
09e4ccaac49..871eef654d2 100644 --- a/dbms/src/Interpreters/Aggregator.cpp +++ b/dbms/src/Interpreters/Aggregator.cpp @@ -44,6 +44,7 @@ extern const char random_aggregate_merge_failpoint[]; extern const char force_agg_on_partial_block[]; extern const char random_fail_in_resize_callback[]; extern const char force_agg_prefetch[]; +extern const char disable_agg_batch_get_key_holder[]; extern const char force_magic_hash[]; } // namespace FailPoints @@ -388,110 +389,6 @@ enum class AggFastPathType #undef M }; -AggregatedDataVariants::Type ChooseAggregationMethodTwoKeys(const AggFastPathType * fast_path_types) -{ - auto tp1 = fast_path_types[0]; - auto tp2 = fast_path_types[1]; - switch (tp1) - { - case AggFastPathType::Number64: - { - switch (tp2) - { - case AggFastPathType::Number64: - return AggregatedDataVariants::Type::serialized; // unreachable. keys64 or keys128 will be used before - case AggFastPathType::StringBin: - return AggregatedDataVariants::Type::two_keys_num64_strbin; - case AggFastPathType::StringBinPadding: - return AggregatedDataVariants::Type::two_keys_num64_strbinpadding; - } - } - case AggFastPathType::StringBin: - { - switch (tp2) - { - case AggFastPathType::Number64: - return AggregatedDataVariants::Type::two_keys_strbin_num64; - case AggFastPathType::StringBin: - return AggregatedDataVariants::Type::two_keys_strbin_strbin; - case AggFastPathType::StringBinPadding: - return AggregatedDataVariants::Type::serialized; // rare case - } - } - case AggFastPathType::StringBinPadding: - { - switch (tp2) - { - case AggFastPathType::Number64: - return AggregatedDataVariants::Type::two_keys_strbinpadding_num64; - case AggFastPathType::StringBin: - return AggregatedDataVariants::Type::serialized; // rare case - case AggFastPathType::StringBinPadding: - return AggregatedDataVariants::Type::two_keys_strbinpadding_strbinpadding; - } - } - } -} - -// return AggregatedDataVariants::Type::serialized if can NOT determine fast path. 
-AggregatedDataVariants::Type ChooseAggregationMethodFastPath( - size_t keys_size, - const DataTypes & types_not_null, - const TiDB::TiDBCollators & collators) -{ - std::array fast_path_types{}; - - if (keys_size == fast_path_types.max_size()) - { - for (size_t i = 0; i < keys_size; ++i) - { - const auto & type = types_not_null[i]; - if (type->isString()) - { - if (collators.empty() || !collators[i]) - { - // use original way - return AggregatedDataVariants::Type::serialized; - } - else - { - switch (collators[i]->getCollatorType()) - { - case TiDB::ITiDBCollator::CollatorType::UTF8MB4_BIN: - case TiDB::ITiDBCollator::CollatorType::UTF8_BIN: - case TiDB::ITiDBCollator::CollatorType::LATIN1_BIN: - case TiDB::ITiDBCollator::CollatorType::ASCII_BIN: - { - fast_path_types[i] = AggFastPathType::StringBinPadding; - break; - } - case TiDB::ITiDBCollator::CollatorType::BINARY: - { - fast_path_types[i] = AggFastPathType::StringBin; - break; - } - default: - { - // for CI COLLATION, use original way - return AggregatedDataVariants::Type::serialized; - } - } - } - } - else if (IsTypeNumber64(type)) - { - fast_path_types[i] = AggFastPathType::Number64; - } - else - { - return AggregatedDataVariants::Type::serialized; - } - } - return ChooseAggregationMethodTwoKeys(fast_path_types.data()); - } - return AggregatedDataVariants::Type::serialized; -} - AggregatedDataVariants::Type Aggregator::chooseAggregationMethod() { auto method = chooseAggregationMethodInner(); @@ -628,40 +525,12 @@ AggregatedDataVariants::Type Aggregator::chooseAggregationMethodInner() /// If single string key - will use hash table with references to it. Strings itself are stored separately in Arena. if (params.keys_size == 1 && types_not_null[0]->isString()) - { - if (params.collators.empty() || !params.collators[0]) - { - // use original way. `Type::one_key_strbin` will generate empty column. 
- return AggregatedDataVariants::Type::key_string; - } - else - { - switch (params.collators[0]->getCollatorType()) - { - case TiDB::ITiDBCollator::CollatorType::UTF8MB4_BIN: - case TiDB::ITiDBCollator::CollatorType::UTF8_BIN: - case TiDB::ITiDBCollator::CollatorType::LATIN1_BIN: - case TiDB::ITiDBCollator::CollatorType::ASCII_BIN: - { - return AggregatedDataVariants::Type::one_key_strbinpadding; - } - case TiDB::ITiDBCollator::CollatorType::BINARY: - { - return AggregatedDataVariants::Type::one_key_strbin; - } - default: - { - // for CI COLLATION, use original way - return AggregatedDataVariants::Type::key_string; - } - } - } - } + return AggregatedDataVariants::Type::key_string; if (params.keys_size == 1 && types_not_null[0]->isFixedString()) return AggregatedDataVariants::Type::key_fixed_string; - return ChooseAggregationMethodFastPath(params.keys_size, types_not_null, params.collators); + return AggregatedDataVariants::Type::serialized; } @@ -688,7 +557,6 @@ void Aggregator::createAggregateStates(AggregateDataPtr & aggregate_data) const } } - /** It's interesting - if you remove `noinline`, then gcc for some reason will inline this function, and the performance decreases (~ 10%). * (Probably because after the inline of this function, more internal functions no longer be inlined.) * Inline does not make sense, since the inner loop is entirely inside this function. @@ -696,52 +564,147 @@ void Aggregator::createAggregateStates(AggregateDataPtr & aggregate_data) const template void NO_INLINE Aggregator::executeImpl( Method & method, - Arena * aggregates_pool, + AggregatedDataVariants & result, AggProcessInfo & agg_process_info, TiDB::TiDBCollators & collators) const { - typename Method::State state(agg_process_info.key_columns, key_sizes, collators); - // 2MB as prefetch threshold, because normally server L2 cache is 1MB. 
static constexpr size_t prefetch_threshold = (2 << 20); #ifndef NDEBUG + // In debug mode, failpoint disable_agg_batch_get_key_holder can be used. bool disable_prefetch = (method.data.getBufferSizeInBytes() < prefetch_threshold); fiu_do_on(FailPoints::force_agg_prefetch, { disable_prefetch = false; }); + + bool disable_batch_get_key_holder = false; + fiu_do_on(FailPoints::disable_agg_batch_get_key_holder, { disable_batch_get_key_holder = true; }); + + if (disable_batch_get_key_holder) + { + if (disable_prefetch) + executeImplInner< + collect_hit_rate, + only_lookup, + /*enable_prefetch=*/false, + /*enable_agg_batch_get_key_holder=*/false>(method, result, agg_process_info, collators); + else + executeImplInner< + collect_hit_rate, + only_lookup, + /*enable_prefetch=*/true, + /*enable_agg_batch_get_key_holder=*/false>(method, result, agg_process_info, collators); + } + else + { + if (disable_prefetch) + executeImplInner< + collect_hit_rate, + only_lookup, + /*enable_prefetch=*/false, + /*enable_agg_batch_get_key_holder=*/true>(method, result, agg_process_info, collators); + else + executeImplInner< + collect_hit_rate, + only_lookup, + /*enable_prefetch=*/true, + /*enable_agg_batch_get_key_holder=*/true>(method, result, agg_process_info, collators); + } #else const bool disable_prefetch = (method.data.getBufferSizeInBytes() < prefetch_threshold); + if (disable_prefetch) + executeImplInner< + collect_hit_rate, + only_lookup, + /*enable_prefetch=*/false, + /*enable_agg_batch_get_key_holder=*/true>(method, result, agg_process_info, collators); + else + executeImplInner< + collect_hit_rate, + only_lookup, + /*enable_prefetch=*/true, + /*enable_agg_batch_get_key_holder=*/true>(method, result, agg_process_info, collators); #endif +} - if constexpr (Method::State::is_serialized_key) +template < + bool collect_hit_rate, + bool only_lookup, + bool enable_prefetch, + bool enable_batch_get_key_holder, + typename Method> +void Aggregator::executeImplInner( + Method & method, + 
AggregatedDataVariants & result, + AggProcessInfo & agg_process_info, + TiDB::TiDBCollators & collators) const +{ + auto * aggregates_pool = result.aggregates_pool; + typename Method::State state(agg_process_info.key_columns, key_sizes, collators); + + // For key_serialized, memory allocation and key serialization will be batch-wise. + // For key_string, collation decode will be batch-wise. + static constexpr bool batch_get_key_holder = Method::State::can_batch_get_key_holder && enable_batch_get_key_holder; + if constexpr (batch_get_key_holder) { - executeImplBatch(method, state, aggregates_pool, agg_process_info); + state.initBatchHandler(agg_process_info.start_row, agg_mini_batch); + result.batch_get_key_holder = true; } - else if constexpr (Method::Data::is_string_hash_map) + using KeyHolderType = typename std::conditional< + batch_get_key_holder, + typename Method::State::BatchKeyHolderType, + typename Method::State::KeyHolderType>::type; + + if constexpr (Method::Data::is_string_hash_map) { // StringHashMap doesn't support prefetch. 
- executeImplBatch(method, state, aggregates_pool, agg_process_info); + executeImplBatch< + collect_hit_rate, + only_lookup, + /*enable_prefetch=*/false, + batch_get_key_holder, + KeyHolderType>(method, state, aggregates_pool, agg_process_info); } else { - if (disable_prefetch) - executeImplBatch(method, state, aggregates_pool, agg_process_info); - else - executeImplBatch(method, state, aggregates_pool, agg_process_info); + executeImplBatch( + method, + state, + aggregates_pool, + agg_process_info); } } -template +template std::optional::ResultType> Aggregator::emplaceOrFindKey( Method & method, typename Method::State & state, - typename Method::State::Derived::KeyHolderType && key_holder, + KeyHolderType & key_holder, size_t hashval) const { try { if constexpr (only_lookup) - return state.template findKey(method.data, std::move(key_holder), hashval); + return state.template findKey(method.data, key_holder, hashval); else - return state.template emplaceKey(method.data, std::move(key_holder), hashval); + return state.template emplaceKey(method.data, key_holder, hashval); + } + catch (ResizeException &) + { + return {}; + } +} + +template +std::optional::ResultType> Aggregator::emplaceOrFindKey( + Method & method, + typename Method::State & state, + KeyHolderType & key_holder) const +{ + try + { + if constexpr (only_lookup) + return state.template findKey(method.data, key_holder); + else + return state.template emplaceKey(method.data, key_holder); } catch (ResizeException &) { @@ -770,30 +733,44 @@ std::optional::Res } } -template -ALWAYS_INLINE inline void prepareBatch( +template +ALWAYS_INLINE inline void setupKeyHolderAndHashVal( size_t row_idx, - size_t end_row, + size_t batch_size, std::vector & hashvals, - std::vector & key_holders, + std::vector & key_holders, Arena * aggregates_pool, std::vector & sort_key_containers, Method & method, typename Method::State & state) { - assert(hashvals.size() == key_holders.size()); + key_holders.resize(batch_size); + if 
constexpr (enable_prefetch) + hashvals.resize(batch_size); - for (size_t i = row_idx, j = 0; i < row_idx + hashvals.size() && i < end_row; ++i, ++j) + for (size_t i = row_idx, j = 0; i < row_idx + batch_size; ++i, ++j) { - key_holders[j] = static_cast(&state)->getKeyHolder( - i, - aggregates_pool, - sort_key_containers); - hashvals[j] = method.data.hash(keyHolderGetKey(key_holders[j])); + if constexpr (batch_get_key_holder) + key_holders[j] + = static_cast(&state)->getKeyHolderBatch(j, aggregates_pool); + else + key_holders[j] = static_cast(&state)->getKeyHolder( + i, + aggregates_pool, + sort_key_containers); + + if constexpr (enable_prefetch) + hashvals[j] = method.data.hash(keyHolderGetKey(key_holders[j])); } } -template +template < + bool collect_hit_rate, + bool only_lookup, + bool enable_prefetch, + bool batch_get_key_holder, + typename KeyHolderType, + typename Method> ALWAYS_INLINE void Aggregator::executeImplBatch( Method & method, typename Method::State & state, @@ -805,11 +782,13 @@ ALWAYS_INLINE void Aggregator::executeImplBatch( /// Optimization for special case when there are no aggregate functions. if (params.aggregates_size == 0) - return handleOneBatch( - method, - state, - agg_process_info, - aggregates_pool); + return handleOneBatch< + collect_hit_rate, + only_lookup, + enable_prefetch, + batch_get_key_holder, + /*compute_agg_data=*/false, + KeyHolderType>(method, state, agg_process_info, aggregates_pool); /// Optimization for special case when aggregating by 8bit key. if constexpr (std::is_same_v) @@ -851,14 +830,23 @@ ALWAYS_INLINE void Aggregator::executeImplBatch( } /// Generic case. 
- return handleOneBatch( - method, - state, - agg_process_info, - aggregates_pool); -} - -template + return handleOneBatch< + collect_hit_rate, + only_lookup, + enable_prefetch, + batch_get_key_holder, + /*compute_agg_data=*/true, + KeyHolderType>(method, state, agg_process_info, aggregates_pool); +} + +template < + bool collect_hit_rate, + bool only_lookup, + bool enable_prefetch, + bool batch_get_key_holder, + bool compute_agg_data, + typename KeyHolderType, + typename Method> void Aggregator::handleOneBatch( Method & method, typename Method::State & state, @@ -884,28 +872,38 @@ void Aggregator::handleOneBatch( size_t i = agg_process_info.start_row; const size_t end = agg_process_info.start_row + rows; + Arena temp_batch_pool; size_t mini_batch_size = rows; std::vector hashvals; - std::vector key_holders; - if constexpr (enable_prefetch) + std::vector key_holders; + if constexpr (enable_prefetch || batch_get_key_holder) { - // mini batch will only be used when HashTable is big(a.k.a enable_prefetch is true), - // which can reduce cache miss of agg data. + // mini batch will only be used when HashTable is big(because reduce cache miss of agg data), + // or when need to get key batch-wise. mini_batch_size = agg_mini_batch; - hashvals.resize(agg_mini_batch); - key_holders.resize(agg_mini_batch); } // i is the begin row index of each mini batch. 
while (i < end) { - if constexpr (enable_prefetch) - { - if unlikely (i + mini_batch_size > end) - mini_batch_size = end - i; + if unlikely (i + mini_batch_size > end) + mini_batch_size = end - i; + + if constexpr (batch_get_key_holder) + state.prepareNextBatch(&temp_batch_pool, mini_batch_size); - prepareBatch(i, end, hashvals, key_holders, aggregates_pool, sort_key_containers, method, state); + if constexpr (enable_prefetch || batch_get_key_holder) + { + setupKeyHolderAndHashVal( + i, + mini_batch_size, + hashvals, + key_holders, + aggregates_pool, + sort_key_containers, + method, + state); } const auto cur_batch_end = i + mini_batch_size; @@ -920,8 +918,11 @@ void Aggregator::handleOneBatch( if likely (k + agg_prefetch_step < hashvals.size()) method.data.prefetch(hashvals[k + agg_prefetch_step]); - emplace_result_holder - = emplaceOrFindKey(method, state, std::move(key_holders[k]), hashvals[k]); + emplace_result_holder = emplaceOrFindKey(method, state, key_holders[k], hashvals[k]); + } + else if constexpr (batch_get_key_holder) + { + emplace_result_holder = emplaceOrFindKey(method, state, key_holders[k]); } else { @@ -941,13 +942,9 @@ void Aggregator::handleOneBatch( if constexpr (compute_agg_data) { if (emplace_result.isFound()) - { aggregate_data = emplace_result.getMapped(); - } else - { agg_process_info.not_found_rows.push_back(j); - } } else { @@ -991,6 +988,9 @@ void Aggregator::handleOneBatch( processed_rows = j; } + if constexpr (batch_get_key_holder) + temp_batch_pool.rollback(); + if unlikely (!processed_rows.has_value()) break; @@ -1194,7 +1194,7 @@ bool Aggregator::executeOnBlockImpl( { \ executeImpl( \ *ToAggregationMethodPtr(NAME, result.aggregation_method_impl), \ - result.aggregates_pool, \ + result, \ agg_process_info, \ params.collators); \ break; \ @@ -1301,24 +1301,52 @@ Block Aggregator::convertOneBucketToBlock( bool final, size_t bucket) const { -#define FILLER_DEFINE(name, skip_convert_key) \ - auto filler_##name = [bucket, &method, arena, 
this]( \ - const Sizes & key_sizes, \ - MutableColumns & key_columns, \ - AggregateColumnsData & aggregate_columns, \ - MutableColumns & final_aggregate_columns, \ - bool final_) { \ - using METHOD_TYPE = std::decay_t; \ - using DATA_TYPE = std::decay_t; \ - convertToBlockImpl( \ - method, \ - method.data.impls[bucket], \ - key_sizes, \ - key_columns, \ - aggregate_columns, \ - final_aggregate_columns, \ - arena, \ - final_); \ + const bool batch_get_key_holder = data_variants.batch_get_key_holder; +#define FILLER_DEFINE(name, skip_convert_key) \ + auto filler_##name = [bucket, &method, arena, this, batch_get_key_holder]( \ + const Sizes & key_sizes, \ + MutableColumns & key_columns, \ + AggregateColumnsData & aggregate_columns, \ + MutableColumns & final_aggregate_columns, \ + bool final_) { \ + (void)batch_get_key_holder; \ + using METHOD_TYPE = std::decay_t; \ + using DATA_TYPE = std::decay_t; \ + if constexpr (METHOD_TYPE::State::is_serialized_key && METHOD_TYPE::State::can_batch_get_key_holder) \ + { \ + if (batch_get_key_holder) \ + convertToBlockImpl( \ + method, \ + method.data.impls[bucket], \ + key_sizes, \ + key_columns, \ + aggregate_columns, \ + final_aggregate_columns, \ + arena, \ + final_); \ + else \ + convertToBlockImpl( \ + method, \ + method.data.impls[bucket], \ + key_sizes, \ + key_columns, \ + aggregate_columns, \ + final_aggregate_columns, \ + arena, \ + final_); \ + } \ + else \ + { \ + convertToBlockImpl( \ + method, \ + method.data.impls[bucket], \ + key_sizes, \ + key_columns, \ + aggregate_columns, \ + final_aggregate_columns, \ + arena, \ + final_); \ + } \ } FILLER_DEFINE(convert_key, false); @@ -1361,22 +1389,50 @@ BlocksList Aggregator::convertOneBucketToBlocks( bool final, size_t bucket) const { -#define FILLER_DEFINE(name, skip_convert_key) \ - auto filler_##name = [bucket, &method, arena, this]( \ - const Sizes & key_sizes, \ - std::vector & key_columns_vec, \ - std::vector & aggregate_columns_vec, \ - std::vector & 
final_aggregate_columns_vec, \ - bool final_) { \ - convertToBlocksImpl( \ - method, \ - method.data.impls[bucket], \ - key_sizes, \ - key_columns_vec, \ - aggregate_columns_vec, \ - final_aggregate_columns_vec, \ - arena, \ - final_); \ + const auto batch_get_key_holder = data_variants.batch_get_key_holder; +#define FILLER_DEFINE(name, skip_convert_key) \ + auto filler_##name = [bucket, &method, arena, this, batch_get_key_holder]( \ + const Sizes & key_sizes, \ + std::vector & key_columns_vec, \ + std::vector & aggregate_columns_vec, \ + std::vector & final_aggregate_columns_vec, \ + bool final_) { \ + (void)batch_get_key_holder; \ + if constexpr (Method::State::is_serialized_key && Method::State::can_batch_get_key_holder) \ + { \ + if (batch_get_key_holder) \ + convertToBlocksImpl( \ + method, \ + method.data.impls[bucket], \ + key_sizes, \ + key_columns_vec, \ + aggregate_columns_vec, \ + final_aggregate_columns_vec, \ + arena, \ + final_); \ + else \ + convertToBlocksImpl( \ + method, \ + method.data.impls[bucket], \ + key_sizes, \ + key_columns_vec, \ + aggregate_columns_vec, \ + final_aggregate_columns_vec, \ + arena, \ + final_); \ + } \ + else \ + { \ + convertToBlocksImpl( \ + method, \ + method.data.impls[bucket], \ + key_sizes, \ + key_columns_vec, \ + aggregate_columns_vec, \ + final_aggregate_columns_vec, \ + arena, \ + final_); \ + } \ }; FILLER_DEFINE(convert_key, false); @@ -1502,7 +1558,7 @@ void Aggregator::execute(const BlockInputStreamPtr & stream, AggregatedDataVaria src_bytes / elapsed_seconds / 1048576.0); } -template +template void Aggregator::convertToBlockImpl( Method & method, Table & data, @@ -1522,7 +1578,7 @@ void Aggregator::convertToBlockImpl( raw_key_columns.push_back(column.get()); if (final) - convertToBlockImplFinal( + convertToBlockImplFinal( method, data, key_sizes, @@ -1530,7 +1586,7 @@ void Aggregator::convertToBlockImpl( final_aggregate_columns, arena); else - convertToBlockImplNotFinal( + convertToBlockImplNotFinal( method, 
data, key_sizes, @@ -1541,7 +1597,7 @@ void Aggregator::convertToBlockImpl( data.clearAndShrink(); } -template +template void Aggregator::convertToBlocksImpl( Method & method, Table & data, @@ -1570,7 +1626,7 @@ void Aggregator::convertToBlocksImpl( } if (final) - convertToBlocksImplFinal( + convertToBlocksImplFinal( method, data, key_sizes, @@ -1578,7 +1634,7 @@ void Aggregator::convertToBlocksImpl( final_aggregate_columns_vec, arena); else - convertToBlocksImplNotFinal( + convertToBlocksImplNotFinal( method, data, key_sizes, @@ -1656,102 +1712,7 @@ inline void Aggregator::insertAggregatesIntoColumns( std::rethrow_exception(exception); } -template -struct AggregatorMethodInitKeyColumnHelper -{ - Method & method; - explicit AggregatorMethodInitKeyColumnHelper(Method & method_) - : method(method_) - {} - ALWAYS_INLINE inline void initAggKeys(size_t, std::vector &) {} - template - ALWAYS_INLINE inline void insertKeyIntoColumns( - const Key & key, - std::vector & key_columns, - const Sizes & sizes, - const TiDB::TiDBCollators & collators) - { - method.insertKeyIntoColumns(key, key_columns, sizes, collators); - } -}; - -template -struct AggregatorMethodInitKeyColumnHelper> -{ - using Method = AggregationMethodFastPathTwoKeysNoCache; - size_t index{}; - std::function &, size_t)> insert_key_into_columns_function_ptr{}; - - Method & method; - explicit AggregatorMethodInitKeyColumnHelper(Method & method_) - : method(method_) - {} - - ALWAYS_INLINE inline void initAggKeys(size_t rows, std::vector & key_columns) - { - index = 0; - if (key_columns.size() == 1) - { - Method::template initAggKeys(rows, key_columns[0]); - insert_key_into_columns_function_ptr - = AggregationMethodFastPathTwoKeysNoCache::insertKeyIntoColumnsOneKey; - } - else if (key_columns.size() == 2) - { - Method::template initAggKeys(rows, key_columns[0]); - Method::template initAggKeys(rows, key_columns[1]); - insert_key_into_columns_function_ptr - = 
AggregationMethodFastPathTwoKeysNoCache::insertKeyIntoColumnsTwoKey; - } - else - { - throw Exception("unexpected key_columns size for AggMethodFastPathTwoKey: {}", key_columns.size()); - } - } - ALWAYS_INLINE inline void insertKeyIntoColumns( - const StringRef & key, - std::vector & key_columns, - const Sizes &, - const TiDB::TiDBCollators &) - { - assert(insert_key_into_columns_function_ptr); - insert_key_into_columns_function_ptr(key, key_columns, index); - ++index; - } -}; - -template -struct AggregatorMethodInitKeyColumnHelper> -{ - using Method = AggregationMethodOneKeyStringNoCache; - size_t index{}; - - Method & method; - explicit AggregatorMethodInitKeyColumnHelper(Method & method_) - : method(method_) - {} - - void initAggKeys(size_t rows, std::vector & key_columns) - { - index = 0; - RUNTIME_CHECK_MSG( - key_columns.size() == 1, - "unexpected key_columns size for AggMethodOneKeyString: {}", - key_columns.size()); - Method::initAggKeys(rows, key_columns[0]); - } - ALWAYS_INLINE inline void insertKeyIntoColumns( - const StringRef & key, - std::vector & key_columns, - const Sizes &, - const TiDB::TiDBCollators &) - { - method.insertKeyIntoColumns(key, key_columns, index); - ++index; - } -}; - -template +template void NO_INLINE Aggregator::convertToBlockImplFinal( Method & method, Table & data, @@ -1762,7 +1723,6 @@ void NO_INLINE Aggregator::convertToBlockImplFinal( { assert(key_sizes.size() == key_columns.size()); Sizes key_sizes_ref = key_sizes; // NOLINT - AggregatorMethodInitKeyColumnHelper agg_keys_helper{method}; if constexpr (!skip_convert_key) { auto shuffled_key_sizes = method.shuffleKeyColumns(key_columns, key_sizes); @@ -1773,18 +1733,36 @@ void NO_INLINE Aggregator::convertToBlockImplFinal( RUNTIME_CHECK(params.key_ref_agg_func.empty()); key_sizes_ref = *shuffled_key_sizes; } - agg_keys_helper.initAggKeys(data.size(), key_columns); } + PaddedPODArray key_places; + if constexpr (batch_deserialize_key) + key_places.reserve(data.size()); + // 
Doesn't prefetch agg data, because places[data.size()] is needed, which can be very large. data.forEachValue([&](const auto & key [[maybe_unused]], auto & mapped) { if constexpr (!skip_convert_key) { - agg_keys_helper.insertKeyIntoColumns(key, key_columns, key_sizes_ref, params.collators); + if constexpr (batch_deserialize_key) + { + // Assume key is StringRef, because only key_serialize can be here. + static_assert(std::is_same_v, StringRef>); + key_places.push_back(const_cast(key.data)); + } + else + { + method.insertKeyIntoColumns(key, key_columns, key_sizes_ref, params.collators); + } } insertAggregatesIntoColumns(mapped, final_aggregate_columns, arena); }); + + if constexpr (!skip_convert_key && batch_deserialize_key) + { + if (!key_places.empty()) + method.insertKeyIntoColumnsBatch(key_places, key_columns); + } } namespace @@ -1803,28 +1781,9 @@ std::optional shuffleKeyColumnsForKeyColumnsVec( } return shuffled_key_sizes; } -template -std::vector>> initAggKeysForKeyColumnsVec( - Method & method, - std::vector> & key_columns_vec, - size_t max_block_size, - size_t total_row_count) -{ - std::vector>> agg_keys_helpers; - size_t block_row_count = max_block_size; - for (size_t i = 0; i < key_columns_vec.size(); ++i) - { - if (i == key_columns_vec.size() - 1 && total_row_count % block_row_count != 0) - /// update block_row_count for the last block - block_row_count = total_row_count % block_row_count; - agg_keys_helpers.push_back(std::make_unique>(method)); - agg_keys_helpers.back()->initAggKeys(block_row_count, key_columns_vec[i]); - } - return agg_keys_helpers; -} } // namespace -template +template void NO_INLINE Aggregator::convertToBlocksImplFinal( Method & method, Table & data, @@ -1840,7 +1799,6 @@ void NO_INLINE Aggregator::convertToBlocksImplFinal( assert(key_columns.size() == key_sizes.size()); } #endif - std::vector>>> agg_keys_helpers; Sizes key_sizes_ref = key_sizes; // NOLINT if constexpr (!skip_convert_key) { @@ -1850,32 +1808,56 @@ void NO_INLINE 
Aggregator::convertToBlocksImplFinal( RUNTIME_CHECK(params.key_ref_agg_func.empty()); key_sizes_ref = *shuffled_key_sizes; } - agg_keys_helpers = initAggKeysForKeyColumnsVec(method, key_columns_vec, params.max_block_size, data.size()); } size_t data_index = 0; const auto rows = data.size(); std::unique_ptr places(new AggregateDataPtr[rows]); + PaddedPODArray key_places; + if constexpr (batch_deserialize_key) + key_places.reserve(params.max_block_size); + size_t current_bound = params.max_block_size; size_t key_columns_vec_index = 0; data.forEachValue([&](const auto & key [[maybe_unused]], auto & mapped) { if constexpr (!skip_convert_key) { - agg_keys_helpers[key_columns_vec_index] - ->insertKeyIntoColumns(key, key_columns_vec[key_columns_vec_index], key_sizes_ref, params.collators); + if constexpr (batch_deserialize_key) + { + // Assume key is StringRef, because only key_serialize can be here. + static_assert(std::is_same_v, StringRef>); + key_places.push_back(const_cast(key.data)); + } + else + { + method + .insertKeyIntoColumns(key, key_columns_vec[key_columns_vec_index], key_sizes_ref, params.collators); + } } places[data_index] = mapped; ++data_index; if unlikely (data_index == current_bound) { + if constexpr (!skip_convert_key && batch_deserialize_key) + { + method.insertKeyIntoColumnsBatch(key_places, key_columns_vec[key_columns_vec_index]); + key_places.clear(); + } + ++key_columns_vec_index; current_bound += params.max_block_size; } }); + if constexpr (!skip_convert_key && batch_deserialize_key) + { + if (!key_places.empty()) + method.insertKeyIntoColumnsBatch(key_places, key_columns_vec[key_columns_vec_index]); + } + data_index = 0; current_bound = params.max_block_size; key_columns_vec_index = 0; @@ -1895,7 +1877,7 @@ void NO_INLINE Aggregator::convertToBlocksImplFinal( } } -template +template void NO_INLINE Aggregator::convertToBlockImplNotFinal( Method & method, Table & data, @@ -1904,7 +1886,6 @@ void NO_INLINE Aggregator::convertToBlockImplNotFinal( 
AggregateColumnsData & aggregate_columns) const { assert(key_sizes.size() == key_columns.size()); - AggregatorMethodInitKeyColumnHelper agg_keys_helper{method}; Sizes key_sizes_ref = key_sizes; // NOLINT if constexpr (!skip_convert_key) { @@ -1914,13 +1895,25 @@ void NO_INLINE Aggregator::convertToBlockImplNotFinal( RUNTIME_CHECK(params.key_ref_agg_func.empty()); key_sizes_ref = *shuffled_key_sizes; } - agg_keys_helper.initAggKeys(data.size(), key_columns); } + PaddedPODArray key_places; + if constexpr (batch_deserialize_key) + key_places.reserve(data.size()); + data.forEachValue([&](const auto & key [[maybe_unused]], auto & mapped) { if constexpr (!skip_convert_key) { - agg_keys_helper.insertKeyIntoColumns(key, key_columns, key_sizes_ref, params.collators); + if constexpr (batch_deserialize_key) + { + // Assume key is StringRef, because only key_serialize can be here. + static_assert(std::is_same_v, StringRef>); + key_places.push_back(const_cast(key.data)); + } + else + { + method.insertKeyIntoColumns(key, key_columns, key_sizes_ref, params.collators); + } } /// reserved, so push_back does not throw exceptions @@ -1929,9 +1922,15 @@ void NO_INLINE Aggregator::convertToBlockImplNotFinal( mapped = nullptr; }); + + if constexpr (!skip_convert_key && batch_deserialize_key) + { + if (!key_places.empty()) + method.insertKeyIntoColumnsBatch(key_places, key_columns); + } } -template +template void NO_INLINE Aggregator::convertToBlocksImplNotFinal( Method & method, Table & data, @@ -1945,7 +1944,6 @@ void NO_INLINE Aggregator::convertToBlocksImplNotFinal( assert(key_sizes.size() == key_columns.size()); } #endif - std::vector>>> agg_keys_helpers; Sizes key_sizes_ref = key_sizes; // NOLINT if constexpr (!skip_convert_key) { @@ -1955,16 +1953,29 @@ void NO_INLINE Aggregator::convertToBlocksImplNotFinal( RUNTIME_CHECK(params.key_ref_agg_func.empty()); key_sizes_ref = shuffled_key_sizes ? 
*shuffled_key_sizes : key_sizes; } - agg_keys_helpers = initAggKeysForKeyColumnsVec(method, key_columns_vec, params.max_block_size, data.size()); } + PaddedPODArray key_places; + if constexpr (batch_deserialize_key) + key_places.reserve(params.max_block_size); + size_t data_index = 0; + size_t current_bound = params.max_block_size; + size_t key_columns_vec_index = 0; data.forEachValue([&](const auto & key [[maybe_unused]], auto & mapped) { - size_t key_columns_vec_index = data_index / params.max_block_size; if constexpr (!skip_convert_key) { - agg_keys_helpers[key_columns_vec_index] - ->insertKeyIntoColumns(key, key_columns_vec[key_columns_vec_index], key_sizes_ref, params.collators); + if constexpr (batch_deserialize_key) + { + // Assume key is StringRef, because only key_serialize can be here. + static_assert(std::is_same_v, StringRef>); + key_places.push_back(const_cast(key.data)); + } + else + { + method + .insertKeyIntoColumns(key, key_columns_vec[key_columns_vec_index], key_sizes_ref, params.collators); + } } /// reserved, so push_back does not throw exceptions @@ -1973,7 +1984,25 @@ void NO_INLINE Aggregator::convertToBlocksImplNotFinal( ++data_index; mapped = nullptr; + + if unlikely (data_index == current_bound) + { + if constexpr (!skip_convert_key && batch_deserialize_key) + { + method.insertKeyIntoColumnsBatch(key_places, key_columns_vec[key_columns_vec_index]); + key_places.clear(); + } + + ++key_columns_vec_index; + current_bound += params.max_block_size; + } }); + + if constexpr (!skip_convert_key && batch_deserialize_key) + { + if (!key_places.empty()) + method.insertKeyIntoColumnsBatch(key_places, key_columns_vec[++key_columns_vec_index]); + } } template @@ -2181,7 +2210,6 @@ BlocksList Aggregator::prepareBlocksAndFill( return res_list; } - BlocksList Aggregator::prepareBlocksAndFillWithoutKey(AggregatedDataVariants & data_variants, bool final) const { size_t rows = 1; @@ -2223,33 +2251,62 @@ BlocksList 
Aggregator::prepareBlocksAndFillWithoutKey(AggregatedDataVariants & d BlocksList Aggregator::prepareBlocksAndFillSingleLevel(AggregatedDataVariants & data_variants, bool final) const { size_t rows = data_variants.size(); -#define M(NAME, skip_convert_key) \ - case AggregationMethodType(NAME): \ - { \ - auto & tmp_method = *ToAggregationMethodPtr(NAME, data_variants.aggregation_method_impl); \ - auto & tmp_data = ToAggregationMethodPtr(NAME, data_variants.aggregation_method_impl) -> data; \ - convertToBlocksImpl( \ - tmp_method, \ - tmp_data, \ - key_sizes, \ - key_columns_vec, \ - aggregate_columns_vec, \ - final_aggregate_columns_vec, \ - data_variants.aggregates_pool, \ - final_); \ - break; \ + const bool batch_get_key_holder = data_variants.batch_get_key_holder; +#define M(NAME, skip_convert_key) \ + case AggregationMethodType(NAME): \ + { \ + auto & tmp_method = *ToAggregationMethodPtr(NAME, data_variants.aggregation_method_impl); \ + auto & tmp_data = ToAggregationMethodPtr(NAME, data_variants.aggregation_method_impl) -> data; \ + using MethodType = std::decay_t; \ + if constexpr (MethodType::State::is_serialized_key && MethodType::State::can_batch_get_key_holder) \ + { \ + if (batch_get_key_holder) \ + convertToBlocksImpl( \ + tmp_method, \ + tmp_data, \ + key_sizes, \ + key_columns_vec, \ + aggregate_columns_vec, \ + final_aggregate_columns_vec, \ + data_variants.aggregates_pool, \ + final_); \ + else \ + convertToBlocksImpl( \ + tmp_method, \ + tmp_data, \ + key_sizes, \ + key_columns_vec, \ + aggregate_columns_vec, \ + final_aggregate_columns_vec, \ + data_variants.aggregates_pool, \ + final_); \ + } \ + else \ + { \ + convertToBlocksImpl( \ + tmp_method, \ + tmp_data, \ + key_sizes, \ + key_columns_vec, \ + aggregate_columns_vec, \ + final_aggregate_columns_vec, \ + data_variants.aggregates_pool, \ + final_); \ + } \ + break; \ } #define M_skip_convert_key(NAME) M(NAME, true) #define M_convert_key(NAME) M(NAME, false) #define FILLER_DEFINE(name, M_tmp) \ 
- auto filler_##name = [&data_variants, this]( \ + auto filler_##name = [&data_variants, this, batch_get_key_holder]( \ const Sizes & key_sizes, \ std::vector & key_columns_vec, \ std::vector & aggregate_columns_vec, \ std::vector & final_aggregate_columns_vec, \ bool final_) { \ + (void)batch_get_key_holder; \ switch (data_variants.type) \ { \ APPLY_FOR_VARIANTS_SINGLE_LEVEL(M_tmp) \ @@ -2439,6 +2496,9 @@ MergingBucketsPtr Aggregator::mergeAndConvertToBlocks( non_empty_data[i]->aggregates_pools.end()); } + for (auto & data : non_empty_data) + RUNTIME_CHECK(non_empty_data[0]->batch_get_key_holder == data->batch_get_key_holder); + // for single level merge, concurrency must be 1. size_t merge_concurrency = has_at_least_one_two_level ? std::max(max_threads, 1) : 1; return std::make_shared(*this, non_empty_data, final, merge_concurrency); diff --git a/dbms/src/Interpreters/Aggregator.h b/dbms/src/Interpreters/Aggregator.h index fe8e9b855df..6c47dafd8b0 100644 --- a/dbms/src/Interpreters/Aggregator.h +++ b/dbms/src/Interpreters/Aggregator.h @@ -52,7 +52,6 @@ class IBlockOutputStream; template class AggHashTableToBlocksBlockInputStream; - /** Different data structures that can be used for aggregation * For efficiency, the aggregation data itself is put into the pool. * Data and pool ownership (states of aggregate functions) @@ -173,7 +172,7 @@ struct AggregationMethodOneNumber }; /// For the case where there is one string key. 
-template +template struct AggregationMethodString { using Data = TData; @@ -189,55 +188,7 @@ struct AggregationMethodString : data(other.data) {} - using State = ColumnsHashing::HashMethodString; - template - struct EmplaceOrFindKeyResult - { - }; - - template <> - struct EmplaceOrFindKeyResult - { - using ResultType = ColumnsHashing::columns_hashing_impl::EmplaceResultImpl; - }; - - template <> - struct EmplaceOrFindKeyResult - { - using ResultType = ColumnsHashing::columns_hashing_impl::FindResultImpl; - }; - - static bool canUseKeyRefAggFuncOptimization() { return true; } - std::optional shuffleKeyColumns(std::vector &, const Sizes &) { return {}; } - - static void insertKeyIntoColumns( - const StringRef & key, - std::vector & key_columns, - const Sizes &, - const TiDB::TiDBCollators &) - { - static_cast(key_columns[0])->insertData(key.data, key.size); - } -}; - -/// Same as above but without cache -template -struct AggregationMethodStringNoCache -{ - using Data = TData; - using Key = typename Data::key_type; - using Mapped = typename Data::mapped_type; - - Data data; - - AggregationMethodStringNoCache() = default; - - template - explicit AggregationMethodStringNoCache(const Other & other) - : data(other.data) - {} - - using State = ColumnsHashing::HashMethodString; + using State = ColumnsHashing::HashMethodString; template struct EmplaceOrFindKeyResult { @@ -269,207 +220,8 @@ struct AggregationMethodStringNoCache } }; -template -struct AggregationMethodOneKeyStringNoCache -{ - using Data = TData; - using Key = typename Data::key_type; - using Mapped = typename Data::mapped_type; - - Data data; - - AggregationMethodOneKeyStringNoCache() = default; - - template - explicit AggregationMethodOneKeyStringNoCache(const Other & other) - : data(other.data) - {} - - using State = ColumnsHashing::HashMethodStringBin; - template - struct EmplaceOrFindKeyResult - { - }; - - template <> - struct EmplaceOrFindKeyResult - { - using ResultType = 
ColumnsHashing::columns_hashing_impl::EmplaceResultImpl; - }; - - template <> - struct EmplaceOrFindKeyResult - { - using ResultType = ColumnsHashing::columns_hashing_impl::FindResultImpl; - }; - - static bool canUseKeyRefAggFuncOptimization() { return true; } - std::optional shuffleKeyColumns(std::vector &, const Sizes &) { return {}; } - - ALWAYS_INLINE static inline void insertKeyIntoColumns( - const StringRef & key, - std::vector & key_columns, - size_t) - { - /// still need to insert data to key because spill may will use this - static_cast(key_columns[0])->insertData(key.data, key.size); - } - ALWAYS_INLINE static inline void initAggKeys(size_t, IColumn *) {} -}; - -/* -/// Same as above but without cache -template -struct AggregationMethodMultiStringNoCache -{ - using Data = TData; - using Key = typename Data::key_type; - using Mapped = typename Data::mapped_type; - - Data data; - - AggregationMethodMultiStringNoCache() = default; - - template - explicit AggregationMethodMultiStringNoCache(const Other & other) - : data(other.data) - {} - - using State = ColumnsHashing::HashMethodMultiString; - using EmplaceResult = ColumnsHashing::columns_hashing_impl::EmplaceResultImpl; - - std::optional shuffleKeyColumns(std::vector &, const Sizes &) { return {}; } - - static void insertKeyIntoColumns(const StringRef & key, std::vector & key_columns, const Sizes &, const TiDB::TiDBCollators &) - { - const auto * pos = key.data; - for (auto & key_column : key_columns) - pos = static_cast(key_column)->deserializeAndInsertFromArena(pos, nullptr); - } -}; -*/ - -template -struct AggregationMethodFastPathTwoKeysNoCache -{ - using Data = TData; - using Key = typename Data::key_type; - using Mapped = typename Data::mapped_type; - - Data data; - - AggregationMethodFastPathTwoKeysNoCache() = default; - - template - explicit AggregationMethodFastPathTwoKeysNoCache(const Other & other) - : data(other.data) - {} - - using State - = ColumnsHashing::HashMethodFastPathTwoKeysSerialized; 
- template - struct EmplaceOrFindKeyResult - { - }; - - template <> - struct EmplaceOrFindKeyResult - { - using ResultType = ColumnsHashing::columns_hashing_impl::EmplaceResultImpl; - }; - - template <> - struct EmplaceOrFindKeyResult - { - using ResultType = ColumnsHashing::columns_hashing_impl::FindResultImpl; - }; - - static bool canUseKeyRefAggFuncOptimization() { return true; } - std::optional shuffleKeyColumns(std::vector &, const Sizes &) { return {}; } - - template - ALWAYS_INLINE static inline void initAggKeys(size_t rows, IColumn * key_column) - { - auto * column = static_cast(key_column); - column->getData().resize_fill_zero(rows); - } - - ALWAYS_INLINE static inline const char * insertAggKeyIntoColumnString(const char * pos, IColumn * key_column) - { - /// still need to insert data to key because spill may will use this - const size_t string_size = *reinterpret_cast(pos); - pos += sizeof(string_size); - static_cast(key_column)->insertData(pos, string_size); - return pos + string_size; - } - ALWAYS_INLINE static inline void initAggKeyString(size_t, IColumn *) {} - - template <> - ALWAYS_INLINE static inline void initAggKeys(size_t rows, IColumn * key_column) - { - return initAggKeyString(rows, key_column); - } - template <> - ALWAYS_INLINE static inline void initAggKeys( - size_t rows, - IColumn * key_column) - { - return initAggKeyString(rows, key_column); - } - - template - ALWAYS_INLINE static inline const char * insertAggKeyIntoColumn( - const char * pos, - IColumn * key_column, - size_t index) - { - auto * column = static_cast(key_column); - column->getElement(index) = *reinterpret_cast(pos); - return pos + KeyType::ElementSize; - } - template <> - ALWAYS_INLINE static inline const char * insertAggKeyIntoColumn( - const char * pos, - IColumn * key_column, - size_t) - { - return insertAggKeyIntoColumnString(pos, key_column); - } - template <> - ALWAYS_INLINE static inline const char * insertAggKeyIntoColumn( - const char * pos, - IColumn * 
key_column, - size_t) - { - return insertAggKeyIntoColumnString(pos, key_column); - } - - ALWAYS_INLINE static inline void insertKeyIntoColumnsOneKey( - const StringRef & key, - std::vector & key_columns, - size_t index) - { - insertAggKeyIntoColumn(key.data, key_columns[0], index); - } - - ALWAYS_INLINE static inline void insertKeyIntoColumnsTwoKey( - const StringRef & key, - std::vector & key_columns, - size_t index) - { - const auto * pos = key.data; - { - pos = insertAggKeyIntoColumn(pos, key_columns[0], index); - } - { - pos = insertAggKeyIntoColumn(pos, key_columns[1], index); - } - } -}; - - /// For the case where there is one fixed-length string key. -template +template struct AggregationMethodFixedString { using Data = TData; @@ -485,55 +237,7 @@ struct AggregationMethodFixedString : data(other.data) {} - using State = ColumnsHashing::HashMethodFixedString; - template - struct EmplaceOrFindKeyResult - { - }; - - template <> - struct EmplaceOrFindKeyResult - { - using ResultType = ColumnsHashing::columns_hashing_impl::EmplaceResultImpl; - }; - - template <> - struct EmplaceOrFindKeyResult - { - using ResultType = ColumnsHashing::columns_hashing_impl::FindResultImpl; - }; - - static bool canUseKeyRefAggFuncOptimization() { return true; } - std::optional shuffleKeyColumns(std::vector &, const Sizes &) { return {}; } - - static void insertKeyIntoColumns( - const StringRef & key, - std::vector & key_columns, - const Sizes &, - const TiDB::TiDBCollators &) - { - static_cast(key_columns[0])->insertData(key.data, key.size); - } -}; - -/// Same as above but without cache -template -struct AggregationMethodFixedStringNoCache -{ - using Data = TData; - using Key = typename Data::key_type; - using Mapped = typename Data::mapped_type; - - Data data; - - AggregationMethodFixedStringNoCache() = default; - - template - explicit AggregationMethodFixedStringNoCache(const Other & other) - : data(other.data) - {} - - using State = ColumnsHashing::HashMethodFixedString; + 
using State = ColumnsHashing::HashMethodFixedString; template struct EmplaceOrFindKeyResult { @@ -720,6 +424,12 @@ struct AggregationMethodSerialized for (size_t i = 0; i < key_columns.size(); ++i) pos = key_columns[i]->deserializeAndInsertFromArena(pos, collators.empty() ? nullptr : collators[i]); } + + static void insertKeyIntoColumnsBatch(PaddedPODArray & key_places, std::vector & key_columns) + { + for (auto * key_column : key_columns) + key_column->deserializeForCmpAndInsertFromPos(key_places, false); + } }; @@ -761,17 +471,18 @@ struct AggregatedDataVariants : private boost::noncopyable */ AggregatedDataWithoutKey without_key = nullptr; + // When the group by key is inserted into the HashTable using the batch method, + // this flag is set to true, indicating that subsequent reads of the group by key from the HashTable should use the batch method for deserialization. + // This is done both for better performance and because currently, batch and non-batch methods are not compatible. 
+ bool batch_get_key_holder = false; + using AggregationMethod_key8 = AggregationMethodOneNumber; using AggregationMethod_key16 = AggregationMethodOneNumber; using AggregationMethod_key32 = AggregationMethodOneNumber; using AggregationMethod_key64 = AggregationMethodOneNumber; using AggregationMethod_key_int256 = AggregationMethodOneNumber; - using AggregationMethod_key_string = AggregationMethodStringNoCache; - using AggregationMethod_one_key_strbin - = AggregationMethodOneKeyStringNoCache; - using AggregationMethod_one_key_strbinpadding - = AggregationMethodOneKeyStringNoCache; - using AggregationMethod_key_fixed_string = AggregationMethodFixedStringNoCache; + using AggregationMethod_key_string = AggregationMethodString; + using AggregationMethod_key_fixed_string = AggregationMethodFixedString; using AggregationMethod_keys16 = AggregationMethodKeysFixed; using AggregationMethod_keys32 = AggregationMethodKeysFixed; using AggregationMethod_keys64 = AggregationMethodKeysFixed; @@ -783,21 +494,18 @@ struct AggregatedDataVariants : private boost::noncopyable using AggregationMethod_key_int256_two_level = AggregationMethodOneNumber; using AggregationMethod_key_string_two_level - = AggregationMethodStringNoCache; - using AggregationMethod_one_key_strbin_two_level - = AggregationMethodOneKeyStringNoCache; - using AggregationMethod_one_key_strbinpadding_two_level - = AggregationMethodOneKeyStringNoCache; + = AggregationMethodString; using AggregationMethod_key_fixed_string_two_level - = AggregationMethodFixedStringNoCache; + = AggregationMethodFixedString; using AggregationMethod_keys32_two_level = AggregationMethodKeysFixed; using AggregationMethod_keys64_two_level = AggregationMethodKeysFixed; using AggregationMethod_keys128_two_level = AggregationMethodKeysFixed; using AggregationMethod_keys256_two_level = AggregationMethodKeysFixed; using AggregationMethod_serialized_two_level = AggregationMethodSerialized; using AggregationMethod_key64_hash64 = 
AggregationMethodOneNumber; - using AggregationMethod_key_string_hash64 = AggregationMethodStringNoCache; - using AggregationMethod_key_fixed_string_hash64 = AggregationMethodFixedString; + using AggregationMethod_key_string_hash64 = AggregationMethodString; + using AggregationMethod_key_fixed_string_hash64 + = AggregationMethodFixedString; using AggregationMethod_keys128_hash64 = AggregationMethodKeysFixed; using AggregationMethod_keys256_hash64 = AggregationMethodKeysFixed; using AggregationMethod_serialized_hash64 = AggregationMethodSerialized; @@ -831,120 +539,50 @@ struct AggregatedDataVariants : private boost::noncopyable using AggregationMethod_nullable_keys256_magic_hash_two_level = AggregationMethodKeysFixed; - // 2 keys - using AggregationMethod_two_keys_num64_strbin = AggregationMethodFastPathTwoKeysNoCache< - ColumnsHashing::KeyDescNumber64, - ColumnsHashing::KeyDescStringBin, - AggregatedDataWithStringKey>; - using AggregationMethod_two_keys_num64_strbinpadding = AggregationMethodFastPathTwoKeysNoCache< - ColumnsHashing::KeyDescNumber64, - ColumnsHashing::KeyDescStringBinPadding, - AggregatedDataWithStringKey>; - using AggregationMethod_two_keys_strbin_num64 = AggregationMethodFastPathTwoKeysNoCache< - ColumnsHashing::KeyDescStringBin, - ColumnsHashing::KeyDescNumber64, - AggregatedDataWithStringKey>; - using AggregationMethod_two_keys_strbin_strbin = AggregationMethodFastPathTwoKeysNoCache< - ColumnsHashing::KeyDescStringBin, - ColumnsHashing::KeyDescStringBin, - AggregatedDataWithStringKey>; - using AggregationMethod_two_keys_strbinpadding_num64 = AggregationMethodFastPathTwoKeysNoCache< - ColumnsHashing::KeyDescStringBinPadding, - ColumnsHashing::KeyDescNumber64, - AggregatedDataWithStringKey>; - using AggregationMethod_two_keys_strbinpadding_strbinpadding = AggregationMethodFastPathTwoKeysNoCache< - ColumnsHashing::KeyDescStringBinPadding, - ColumnsHashing::KeyDescStringBinPadding, - AggregatedDataWithStringKey>; - - using 
AggregationMethod_two_keys_num64_strbin_two_level = AggregationMethodFastPathTwoKeysNoCache< - ColumnsHashing::KeyDescNumber64, - ColumnsHashing::KeyDescStringBin, - AggregatedDataWithStringKeyTwoLevel>; - using AggregationMethod_two_keys_num64_strbinpadding_two_level = AggregationMethodFastPathTwoKeysNoCache< - ColumnsHashing::KeyDescNumber64, - ColumnsHashing::KeyDescStringBinPadding, - AggregatedDataWithStringKeyTwoLevel>; - using AggregationMethod_two_keys_strbin_num64_two_level = AggregationMethodFastPathTwoKeysNoCache< - ColumnsHashing::KeyDescStringBin, - ColumnsHashing::KeyDescNumber64, - AggregatedDataWithStringKeyTwoLevel>; - using AggregationMethod_two_keys_strbin_strbin_two_level = AggregationMethodFastPathTwoKeysNoCache< - ColumnsHashing::KeyDescStringBin, - ColumnsHashing::KeyDescStringBin, - AggregatedDataWithStringKeyTwoLevel>; - using AggregationMethod_two_keys_strbinpadding_num64_two_level = AggregationMethodFastPathTwoKeysNoCache< - ColumnsHashing::KeyDescStringBinPadding, - ColumnsHashing::KeyDescNumber64, - AggregatedDataWithStringKeyTwoLevel>; - using AggregationMethod_two_keys_strbinpadding_strbinpadding_two_level = AggregationMethodFastPathTwoKeysNoCache< - ColumnsHashing::KeyDescStringBinPadding, - ColumnsHashing::KeyDescStringBinPadding, - AggregatedDataWithStringKeyTwoLevel>; - - // 3 keys - // TODO: use 3 keys if necessary - /// In this and similar macros, the option without_key is not considered. 
-#define APPLY_FOR_AGGREGATED_VARIANTS(M) \ - M(key8, false) \ - M(key16, false) \ - M(key32, false) \ - M(key64, false) \ - M(key_string, false) \ - M(key_fixed_string, false) \ - M(keys16, false) \ - M(keys32, false) \ - M(keys64, false) \ - M(keys128, false) \ - M(keys256, false) \ - M(key_int256, false) \ - M(serialized, false) \ - M(key64_hash64, false) \ - M(key_string_hash64, false) \ - M(key_fixed_string_hash64, false) \ - M(keys128_hash64, false) \ - M(keys256_hash64, false) \ - M(serialized_hash64, false) \ - M(nullable_keys128, false) \ - M(nullable_keys256, false) \ - M(two_keys_num64_strbin, false) \ - M(two_keys_num64_strbinpadding, false) \ - M(two_keys_strbin_num64, false) \ - M(two_keys_strbin_strbin, false) \ - M(two_keys_strbinpadding_num64, false) \ - M(two_keys_strbinpadding_strbinpadding, false) \ - M(one_key_strbin, false) \ - M(one_key_strbinpadding, false) \ - M(key32_two_level, true) \ - M(key64_two_level, true) \ - M(key_int256_two_level, true) \ - M(key_string_two_level, true) \ - M(key_fixed_string_two_level, true) \ - M(keys32_two_level, true) \ - M(keys64_two_level, true) \ - M(keys128_two_level, true) \ - M(keys256_two_level, true) \ - M(serialized_two_level, true) \ - M(nullable_keys128_two_level, true) \ - M(nullable_keys256_two_level, true) \ - M(two_keys_num64_strbin_two_level, true) \ - M(two_keys_num64_strbinpadding_two_level, true) \ - M(two_keys_strbin_num64_two_level, true) \ - M(two_keys_strbin_strbin_two_level, true) \ - M(two_keys_strbinpadding_num64_two_level, true) \ - M(two_keys_strbinpadding_strbinpadding_two_level, true) \ - M(one_key_strbin_two_level, true) \ - M(one_key_strbinpadding_two_level, true) \ - M(keys128_magic_hash, false) \ - M(keys256_magic_hash, false) \ - M(key_int256_magic_hash, false) \ - M(nullable_keys128_magic_hash, false) \ - M(nullable_keys256_magic_hash, false) \ - M(key_int256_magic_hash_two_level, true) \ - M(keys128_magic_hash_two_level, true) \ - M(keys256_magic_hash_two_level, true) \ - 
M(nullable_keys128_magic_hash_two_level, true) \ +#define APPLY_FOR_AGGREGATED_VARIANTS(M) \ + M(key8, false) \ + M(key16, false) \ + M(key32, false) \ + M(key64, false) \ + M(key_string, false) \ + M(key_fixed_string, false) \ + M(keys16, false) \ + M(keys32, false) \ + M(keys64, false) \ + M(keys128, false) \ + M(keys256, false) \ + M(key_int256, false) \ + M(serialized, false) \ + M(key64_hash64, false) \ + M(key_string_hash64, false) \ + M(key_fixed_string_hash64, false) \ + M(keys128_hash64, false) \ + M(keys256_hash64, false) \ + M(serialized_hash64, false) \ + M(nullable_keys128, false) \ + M(nullable_keys256, false) \ + M(key32_two_level, true) \ + M(key64_two_level, true) \ + M(key_int256_two_level, true) \ + M(key_string_two_level, true) \ + M(key_fixed_string_two_level, true) \ + M(keys32_two_level, true) \ + M(keys64_two_level, true) \ + M(keys128_two_level, true) \ + M(keys256_two_level, true) \ + M(serialized_two_level, true) \ + M(nullable_keys128_two_level, true) \ + M(nullable_keys256_two_level, true) \ + M(keys128_magic_hash, false) \ + M(keys256_magic_hash, false) \ + M(key_int256_magic_hash, false) \ + M(nullable_keys128_magic_hash, false) \ + M(nullable_keys256_magic_hash, false) \ + M(key_int256_magic_hash_two_level, true) \ + M(keys128_magic_hash_two_level, true) \ + M(keys256_magic_hash_two_level, true) \ + M(nullable_keys128_magic_hash_two_level, true) \ M(nullable_keys256_magic_hash_two_level, true) enum class Type @@ -1091,14 +729,6 @@ struct AggregatedDataVariants : private boost::noncopyable M(serialized) \ M(nullable_keys128) \ M(nullable_keys256) \ - M(two_keys_num64_strbin) \ - M(two_keys_num64_strbinpadding) \ - M(two_keys_strbin_num64) \ - M(two_keys_strbin_strbin) \ - M(two_keys_strbinpadding_num64) \ - M(two_keys_strbinpadding_strbinpadding) \ - M(one_key_strbin) \ - M(one_key_strbinpadding) \ M(key_int256_magic_hash) \ M(keys128_magic_hash) \ M(keys256_magic_hash) \ @@ -1145,31 +775,23 @@ struct AggregatedDataVariants : private 
boost::noncopyable void setResizeCallbackIfNeeded(size_t thread_num) const; -#define APPLY_FOR_VARIANTS_TWO_LEVEL(M) \ - M(key32_two_level) \ - M(key64_two_level) \ - M(key_int256_two_level) \ - M(key_string_two_level) \ - M(key_fixed_string_two_level) \ - M(keys32_two_level) \ - M(keys64_two_level) \ - M(keys128_two_level) \ - M(keys256_two_level) \ - M(serialized_two_level) \ - M(nullable_keys128_two_level) \ - M(nullable_keys256_two_level) \ - M(two_keys_num64_strbin_two_level) \ - M(two_keys_num64_strbinpadding_two_level) \ - M(two_keys_strbin_num64_two_level) \ - M(two_keys_strbin_strbin_two_level) \ - M(two_keys_strbinpadding_num64_two_level) \ - M(two_keys_strbinpadding_strbinpadding_two_level) \ - M(one_key_strbin_two_level) \ - M(one_key_strbinpadding_two_level) \ - M(key_int256_magic_hash_two_level) \ - M(keys128_magic_hash_two_level) \ - M(keys256_magic_hash_two_level) \ - M(nullable_keys128_magic_hash_two_level) \ +#define APPLY_FOR_VARIANTS_TWO_LEVEL(M) \ + M(key32_two_level) \ + M(key64_two_level) \ + M(key_int256_two_level) \ + M(key_string_two_level) \ + M(key_fixed_string_two_level) \ + M(keys32_two_level) \ + M(keys64_two_level) \ + M(keys128_two_level) \ + M(keys256_two_level) \ + M(serialized_two_level) \ + M(nullable_keys128_two_level) \ + M(nullable_keys256_two_level) \ + M(key_int256_magic_hash_two_level) \ + M(keys128_magic_hash_two_level) \ + M(keys256_magic_hash_two_level) \ + M(nullable_keys128_magic_hash_two_level) \ M(nullable_keys256_magic_hash_two_level) }; @@ -1503,31 +1125,62 @@ class Aggregator template void executeImpl( Method & method, - Arena * aggregates_pool, + AggregatedDataVariants & result, + AggProcessInfo & agg_process_info, + TiDB::TiDBCollators & collators) const; + + template < + bool collect_hit_rate, + bool only_lookup, + bool enable_prefetch, + bool enable_batch_get_key_holder, + typename Method> + void executeImplInner( + Method & method, + AggregatedDataVariants & result, AggProcessInfo & agg_process_info, 
TiDB::TiDBCollators & collators) const; - template + template < + bool collect_hit_rate, + bool only_loopup, + bool enable_prefetch, + bool batch_get_key_holder, + typename KeyHolderType, + typename Method> void executeImplBatch( Method & method, typename Method::State & state, Arena * aggregates_pool, AggProcessInfo & agg_process_info) const; - template + template < + bool collect_hit_rate, + bool only_lookup, + bool enable_prefetch, + bool batch_get_key_holder, + bool compute_agg_data, + typename KeyHolderType, + typename Method> void handleOneBatch( Method & method, typename Method::State & state, AggProcessInfo & agg_process_info, Arena * aggregates_pool) const; - template + template std::optional::ResultType> emplaceOrFindKey( Method & method, typename Method::State & state, - typename Method::State::Derived::KeyHolderType && key_holder, + KeyHolderType & key_holder, size_t hashval) const; + template + std::optional::ResultType> emplaceOrFindKey( + Method & method, + typename Method::State & state, + KeyHolderType & key_holder) const; + template std::optional::ResultType> emplaceOrFindKey( Method & method, @@ -1552,7 +1205,7 @@ class Aggregator template void mergeSingleLevelDataImpl(ManyAggregatedDataVariants & non_empty_data) const; - template + template void convertToBlockImpl( Method & method, Table & data, @@ -1566,7 +1219,7 @@ class Aggregator // The template parameter skip_convert_key indicates whether we can skip deserializing the keys in the HashMap. // For example, select first_row(c1) from t group by c1, where c1 is a string column with collator, // only the result of first_row(c1) needs to be constructed. The key c1 only needs to reference to first_row(c1). 
- template + template void convertToBlocksImpl( Method & method, Table & data, @@ -1577,7 +1230,7 @@ class Aggregator Arena * arena, bool final) const; - template + template void convertToBlockImplFinal( Method & method, Table & data, @@ -1586,7 +1239,7 @@ class Aggregator MutableColumns & final_aggregate_columns, Arena * arena) const; - template + template void convertToBlocksImplFinal( Method & method, Table & data, @@ -1595,7 +1248,7 @@ class Aggregator std::vector & final_aggregate_columns_vec, Arena * arena) const; - template + template void convertToBlockImplNotFinal( Method & method, Table & data, @@ -1603,7 +1256,7 @@ class Aggregator std::vector key_columns, AggregateColumnsData & aggregate_columns) const; - template + template void convertToBlocksImplNotFinal( Method & method, Table & data, diff --git a/dbms/src/Interpreters/Context.cpp b/dbms/src/Interpreters/Context.cpp index e84c357ad73..0642f05728e 100644 --- a/dbms/src/Interpreters/Context.cpp +++ b/dbms/src/Interpreters/Context.cpp @@ -77,7 +77,6 @@ #include #include -#include #include @@ -155,7 +154,6 @@ struct ContextShared mutable DM::ColumnCacheLongTermPtr column_cache_long_term; mutable DM::DeltaIndexManagerPtr delta_index_manager; /// Manage the Delta Indies of Segments. ProcessList process_list; /// Executing queries at the moment. - ViewDependencies view_dependencies; /// Current dependencies ConfigurationPtr users_config; /// Config with the users, profiles and quotas sections. BackgroundProcessingPoolPtr background_pool; /// The thread pool for the background work performed by the tables. 
BackgroundProcessingPoolPtr @@ -755,54 +753,6 @@ void Context::checkDatabaseAccessRightsImpl(const std::string & database_name) c throw Exception(fmt::format("Access denied to database {}", database_name), ErrorCodes::DATABASE_ACCESS_DENIED); } -void Context::addDependency(const DatabaseAndTableName & from, const DatabaseAndTableName & where) -{ - auto lock = getLock(); - checkDatabaseAccessRightsImpl(from.first); - checkDatabaseAccessRightsImpl(where.first); - shared->view_dependencies[from].insert(where); - - // Notify table of dependencies change - auto table = tryGetTable(from.first, from.second); - if (table != nullptr) - table->updateDependencies(); -} - -void Context::removeDependency(const DatabaseAndTableName & from, const DatabaseAndTableName & where) -{ - auto lock = getLock(); - checkDatabaseAccessRightsImpl(from.first); - checkDatabaseAccessRightsImpl(where.first); - shared->view_dependencies[from].erase(where); - - // Notify table of dependencies change - auto table = tryGetTable(from.first, from.second); - if (table != nullptr) - table->updateDependencies(); -} - -Dependencies Context::getDependencies(const String & database_name, const String & table_name) const -{ - auto lock = getLock(); - - String db = resolveDatabase(database_name, current_database); - - if (database_name.empty() && tryGetExternalTable(table_name)) - { - /// Table is temporary. Access granted. 
- } - else - { - checkDatabaseAccessRightsImpl(db); - } - - auto iter = shared->view_dependencies.find(DatabaseAndTableName(db, table_name)); - if (iter == shared->view_dependencies.end()) - return {}; - - return Dependencies(iter->second.begin(), iter->second.end()); -} - bool Context::isTableExist(const String & database_name, const String & table_name) const { auto lock = getLock(); @@ -822,12 +772,6 @@ bool Context::isDatabaseExist(const String & database_name) const return shared->databases.end() != shared->databases.find(db); } -bool Context::isExternalTableExist(const String & table_name) const -{ - return external_tables.end() != external_tables.find(table_name); -} - - void Context::assertTableExists(const String & database_name, const String & table_name) const { auto lock = getLock(); @@ -891,39 +835,6 @@ void Context::assertDatabaseDoesntExist(const String & database_name) const ErrorCodes::DATABASE_ALREADY_EXISTS); } - -Tables Context::getExternalTables() const -{ - auto lock = getLock(); - - Tables res; - for (const auto & table : external_tables) - res[table.first] = table.second.first; - - if (session_context && session_context != this) - { - Tables buf = session_context->getExternalTables(); - res.insert(buf.begin(), buf.end()); - } - else if (global_context && global_context != this) - { - Tables buf = global_context->getExternalTables(); - res.insert(buf.begin(), buf.end()); - } - return res; -} - - -StoragePtr Context::tryGetExternalTable(const String & table_name) const -{ - auto jt = external_tables.find(table_name); - if (external_tables.end() == jt) - return StoragePtr(); - - return jt->second.first; -} - - StoragePtr Context::getTable(const String & database_name, const String & table_name) const { Exception exc; @@ -944,13 +855,6 @@ StoragePtr Context::getTableImpl(const String & database_name, const String & ta { auto lock = getLock(); - if (database_name.empty()) - { - StoragePtr res = tryGetExternalTable(table_name); - if (res) - return 
res; - } - String db = resolveDatabase(database_name, current_database); checkDatabaseAccessRightsImpl(db); @@ -977,30 +881,6 @@ StoragePtr Context::getTableImpl(const String & database_name, const String & ta return table; } - -void Context::addExternalTable(const String & table_name, const StoragePtr & storage, const ASTPtr & ast) -{ - if (external_tables.end() != external_tables.find(table_name)) - throw Exception( - fmt::format("Temporary table {} already exists.", backQuoteIfNeed(table_name)), - ErrorCodes::TABLE_ALREADY_EXISTS); - - external_tables[table_name] = std::pair(storage, ast); -} - -StoragePtr Context::tryRemoveExternalTable(const String & table_name) -{ - auto it = external_tables.find(table_name); - - if (external_tables.end() == it) - return StoragePtr(); - - auto storage = it->second.first; - external_tables.erase(it); - return storage; -} - - StoragePtr Context::executeTableFunction(const ASTPtr & table_expression) { /// Slightly suboptimal. @@ -1100,17 +980,6 @@ ASTPtr Context::getCreateTableQuery(const String & database_name, const String & return shared->databases[db]->getCreateTableQuery(*this, table_name); } -ASTPtr Context::getCreateExternalTableQuery(const String & table_name) const -{ - auto jt = external_tables.find(table_name); - if (external_tables.end() == jt) - throw Exception( - fmt::format("Temporary table {} doesn't exist", backQuoteIfNeed(table_name)), - ErrorCodes::UNKNOWN_TABLE); - - return jt->second.second; -} - ASTPtr Context::getCreateDatabaseQuery(const String & database_name) const { auto lock = getLock(); @@ -1147,25 +1016,15 @@ void Context::setSettings(const Settings & settings_) void Context::setSetting(const String & name, const Field & value) { - if (name == "profile") - { - auto lock = getLock(); - settings.setProfile(value.safeGet(), *shared->users_config); - } - else - settings.set(name, value); + assert(name != "profile"); + settings.set(name, value); } void Context::setSetting(const String & name, const 
std::string & value) { - if (name == "profile") - { - auto lock = getLock(); - settings.setProfile(value, *shared->users_config); - } - else - settings.set(name, value); + assert(name != "profile"); + settings.set(name, value); } @@ -1971,15 +1830,6 @@ SharedContextDisaggPtr Context::getSharedContextDisagg() const return shared->ctx_disagg; } -UInt16 Context::getTCPPort() const -{ - auto lock = getLock(); - - auto & config = getConfigRef(); - return config.getInt("tcp_port"); -} - - void Context::initializeSystemLogs() { auto lock = getLock(); @@ -2070,7 +1920,8 @@ void Context::shutdown() void Context::setDefaultProfiles() { shared->default_profile_name = "default"; - setSetting("profile", shared->default_profile_name); + auto lock = getLock(); + settings.setProfile(shared->default_profile_name, *shared->users_config); is_config_loaded = true; } @@ -2205,12 +2056,12 @@ void Context::initRegionBlocklist(const std::unordered_set & region_id auto lock = getLock(); shared->region_blocklist = region_ids; } -bool Context::isRegionInBlocklist(const RegionID region_id) +bool Context::isRegionInBlocklist(const RegionID region_id) const { auto lock = getLock(); return shared->region_blocklist.count(region_id) > 0; } -bool Context::isRegionsContainsInBlocklist(const std::vector & regions) +bool Context::isRegionsContainsInBlocklist(const std::vector & regions) const { auto lock = getLock(); for (const auto region : regions) diff --git a/dbms/src/Interpreters/Context.h b/dbms/src/Interpreters/Context.h index 2578fa36767..4f346ee2de8 100644 --- a/dbms/src/Interpreters/Context.h +++ b/dbms/src/Interpreters/Context.h @@ -28,7 +28,6 @@ #include #include #include -#include #include #include @@ -123,10 +122,6 @@ using GlobalPageIdAllocatorPtr = std::shared_ptr; /// (database name, table name) using DatabaseAndTableName = std::pair; -/// Table -> set of table-views that make SELECT from it. 
-using ViewDependencies = std::map>; -using Dependencies = std::vector; - using TableAndCreateAST = std::pair; using TableAndCreateASTs = std::map; @@ -158,7 +153,6 @@ class Context /// Format, used when server formats data by itself and if query does not have FORMAT specification. /// Thus, used in HTTP interface. If not specified - then some globally default format is used. String default_format; - TableAndCreateASTs external_tables; /// Temporary tables. Tables table_function_results; /// Temporary tables obtained by execution of table functions. Keyed by AST tree id. Context * query_context = nullptr; Context * session_context = nullptr; /// Session context or nullptr. Could be equal to this. @@ -223,7 +217,6 @@ class Context void setPath(const String & path); void setTemporaryPath(const String & path); void setFlagsPath(const String & path); - void setUserFilesPath(const String & path); void setPathPool( const Strings & main_data_paths, @@ -238,6 +231,11 @@ class Context void setConfig(const ConfigurationPtr & config); Poco::Util::AbstractConfiguration & getConfigRef() const; + using ConfigReloadCallback = std::function; + void setConfigReloadCallback(ConfigReloadCallback && callback); + void reloadConfig() const; + void reloadDeltaTreeConfig(const Poco::Util::AbstractConfiguration & config); + /** Take the list of users, quotas and configuration profiles from this config. * The list of users is completely replaced. * The accumulated quota values are not reset if the quota is not deleted. @@ -245,9 +243,11 @@ class Context void setUsersConfig(const ConfigurationPtr & config); ConfigurationPtr getUsersConfig(); + /// Sets default_profile, must be called once during the initialization + void setDefaultProfiles(); + /// Security configuration settings. void setSecurityConfig(Poco::Util::AbstractConfiguration & config, const LoggerPtr & log); - TiFlashSecurityConfigPtr getSecurityConfig(); /// Must be called before getClientInfo. 
@@ -269,14 +269,9 @@ class Context const Poco::Net::IPAddress & address); QuotaForIntervals & getQuota(); - void addDependency(const DatabaseAndTableName & from, const DatabaseAndTableName & where); - void removeDependency(const DatabaseAndTableName & from, const DatabaseAndTableName & where); - Dependencies getDependencies(const String & database_name, const String & table_name) const; - /// Checking the existence of the table/database. Database can be empty - in this case the current database is used. bool isTableExist(const String & database_name, const String & table_name) const; bool isDatabaseExist(const String & database_name) const; - bool isExternalTableExist(const String & table_name) const; void assertTableExists(const String & database_name, const String & table_name) const; /** The parameter check_database_access_rights exists to not check the permissions of the database again, @@ -292,18 +287,18 @@ class Context void assertDatabaseDoesntExist(const String & database_name) const; void checkDatabaseAccessRights(const std::string & database_name) const; - Tables getExternalTables() const; - StoragePtr tryGetExternalTable(const String & table_name) const; StoragePtr getTable(const String & database_name, const String & table_name) const; StoragePtr tryGetTable(const String & database_name, const String & table_name) const; - void addExternalTable(const String & table_name, const StoragePtr & storage, const ASTPtr & ast = {}); - StoragePtr tryRemoveExternalTable(const String & table_name); StoragePtr executeTableFunction(const ASTPtr & table_expression); void addDatabase(const String & database_name, const DatabasePtr & database); DatabasePtr detachDatabase(const String & database_name); + DatabasePtr getDatabase(const String & database_name) const; + DatabasePtr tryGetDatabase(const String & database_name) const; + Databases getDatabases() const; + /// Get an object that protects the table from concurrently executing multiple DDL operations. 
/// If such an object already exists, an exception is thrown. std::unique_ptr getDDLGuard(const String & table, const String & message) const; @@ -322,11 +317,12 @@ class Context void setDefaultFormat(const String & name); Settings getSettings() const; - void setSettings(const Settings & settings_); + const Settings & getSettingsRef() const; + Settings & getSettingsRef(); + void setSettings(const Settings & settings_); /// Set a setting by name. void setSetting(const String & name, const Field & value); - /// Set a setting by name. Read the value in text form from a string (for example, from a config, or from a URL parameter). void setSetting(const String & name, const std::string & value); @@ -338,18 +334,10 @@ class Context size_t max_block_size) const; BlockOutputStreamPtr getOutputFormat(const String & name, WriteBuffer & buf, const Block & sample) const; - /// The port that the server listens for executing SQL queries. - UInt16 getTCPPort() const; - /// Get query for the CREATE table. ASTPtr getCreateTableQuery(const String & database_name, const String & table_name) const; - ASTPtr getCreateExternalTableQuery(const String & table_name) const; ASTPtr getCreateDatabaseQuery(const String & database_name) const; - DatabasePtr getDatabase(const String & database_name) const; - DatabasePtr tryGetDatabase(const String & database_name) const; - Databases getDatabases() const; - std::shared_ptr acquireSession( const String & session_id, std::chrono::steady_clock::duration timeout, @@ -376,9 +364,6 @@ class Context void setQueryContext(Context & context_) { query_context = &context_; } void setSessionContext(Context & context_) { session_context = &context_; } void setGlobalContext(Context & context_) { global_context = &context_; } - const Settings & getSettingsRef() const; - Settings & getSettingsRef(); - void setProgressCallback(ProgressCallback callback); /// Used in InterpreterSelectQuery to pass it to the IProfilingBlockInputStream. 
@@ -509,15 +494,8 @@ class Context /// Get the server uptime in seconds. time_t getUptimeSeconds() const; - using ConfigReloadCallback = std::function; - void setConfigReloadCallback(ConfigReloadCallback && callback); - void reloadConfig() const; - void shutdown(); - /// Sets default_profile, must be called once during the initialization - void setDefaultProfiles(); - void setServerInfo(const ServerInfo & server_info); const std::optional & getServerInfo() const; @@ -529,8 +507,6 @@ class Context /// User name and session identifier. Named sessions are local to users. using SessionKey = std::pair; - void reloadDeltaTreeConfig(const Poco::Util::AbstractConfiguration & config); - size_t getMaxStreams() const; /// For executor, MPPTask, CancelMPPTasks tests. @@ -557,8 +533,8 @@ class Context void initKeyspaceBlocklist(const std::unordered_set & keyspace_ids); bool isKeyspaceInBlocklist(KeyspaceID keyspace_id); void initRegionBlocklist(const std::unordered_set & region_ids); - bool isRegionInBlocklist(RegionID region_id); - bool isRegionsContainsInBlocklist(const std::vector & regions); + bool isRegionInBlocklist(RegionID region_id) const; + bool isRegionsContainsInBlocklist(const std::vector & regions) const; bool initializeStoreIdBlockList(const String &); const std::unordered_set * getStoreIdBlockList() const; diff --git a/dbms/src/Interpreters/DictionaryFactory.cpp b/dbms/src/Interpreters/DictionaryFactory.cpp index 6af417319b6..c2bfe2f1977 100644 --- a/dbms/src/Interpreters/DictionaryFactory.cpp +++ b/dbms/src/Interpreters/DictionaryFactory.cpp @@ -13,7 +13,6 @@ // limitations under the License. 
#include -#include #include #include #include @@ -59,32 +58,7 @@ DictionaryPtr DictionaryFactory::create( const auto & layout_type = keys.front(); - if ("complex_key_cache" == layout_type) - { - if (!dict_struct.key) - throw Exception{ - "'key' is required for dictionary of layout 'complex_key_hashed'", - ErrorCodes::BAD_ARGUMENTS}; - - const auto size = config.getInt(layout_prefix + ".complex_key_cache.size_in_cells"); - if (size == 0) - throw Exception{ - name + ": dictionary of layout 'cache' cannot have 0 cells", - ErrorCodes::TOO_SMALL_BUFFER_SIZE}; - - if (require_nonempty) - throw Exception{ - name + ": dictionary of layout 'cache' cannot have 'require_nonempty' attribute set", - ErrorCodes::BAD_ARGUMENTS}; - - return std::make_unique( - name, - dict_struct, - std::move(source_ptr), - dict_lifetime, - size); - } - else if ("ip_trie" == layout_type) + if ("ip_trie" == layout_type) { if (!dict_struct.key) throw Exception{"'key' is required for dictionary of layout 'ip_trie'", ErrorCodes::BAD_ARGUMENTS}; @@ -129,7 +103,7 @@ DictionaryPtr DictionaryFactory::create( } throw Exception{name + ": unknown dictionary layout type: " + layout_type, ErrorCodes::UNKNOWN_ELEMENT_IN_CONFIG}; -}; +} } // namespace DB diff --git a/dbms/src/Interpreters/ExpressionAnalyzer.cpp b/dbms/src/Interpreters/ExpressionAnalyzer.cpp index 8601df0110b..97bd13d1aa9 100644 --- a/dbms/src/Interpreters/ExpressionAnalyzer.cpp +++ b/dbms/src/Interpreters/ExpressionAnalyzer.cpp @@ -146,7 +146,7 @@ ExpressionAnalyzer::ExpressionAnalyzer( const NamesAndTypesList & source_columns_, const Names & required_result_columns_, size_t subquery_depth_, - bool do_global_, + bool /*do_global_*/, const SubqueriesForSets & subqueries_for_set_) : ast(ast_) , context(context_) @@ -155,7 +155,6 @@ ExpressionAnalyzer::ExpressionAnalyzer( , source_columns(source_columns_) , required_result_columns(required_result_columns_.begin(), required_result_columns_.end()) , storage(storage_) - , do_global(do_global_) , 
subqueries_for_sets(subqueries_for_set_) { select_query = typeid_cast(ast.get()); @@ -216,17 +215,7 @@ ExpressionAnalyzer::ExpressionAnalyzer( /// Delete the unnecessary from `source_columns` list. Create `unknown_required_source_columns`. Form `columns_added_by_join`. collectUsedColumns(); - /// external_tables, subqueries_for_sets for global subqueries. - /// Replaces global subqueries with the generated names of temporary tables that will be sent to remote servers. - initGlobalSubqueriesAndExternalTables(); - /// has_aggregation, aggregation_keys, aggregate_descriptions, aggregated_columns. - /// This analysis should be performed after processing global subqueries, because otherwise, - /// if the aggregate function contains a global subquery, then `analyzeAggregation` method will save - /// in `aggregate_descriptions` the information about the parameters of this aggregate function, among which - /// global subquery. Then, when you call `initGlobalSubqueriesAndExternalTables` method, this - /// the global subquery will be replaced with a temporary table, resulting in aggregate_descriptions - /// will contain out-of-date information, which will lead to an error when the query is executed. analyzeAggregation(); } @@ -582,59 +571,6 @@ void ExpressionAnalyzer::analyzeAggregation() } } - -void ExpressionAnalyzer::initGlobalSubqueriesAndExternalTables() -{ - /// Adds existing external tables (not subqueries) to the external_tables dictionary. - findExternalTables(ast); - - /// Converts GLOBAL subqueries to external tables; Puts them into the external_tables dictionary: name -> StoragePtr. - initGlobalSubqueries(ast); -} - - -void ExpressionAnalyzer::initGlobalSubqueries(ASTPtr & ast) -{ - /// Recursive calls. We do not go into subqueries. - - for (auto & child : ast->children) - if (!typeid_cast(child.get())) - initGlobalSubqueries(child); - - /// Bottom-up actions. - - if (auto * node = typeid_cast(ast.get())) - { - /// For GLOBAL IN. 
- if (do_global && (node->name == "globalIn" || node->name == "globalNotIn")) - addExternalStorage(node->arguments->children.at(1)); - } - else if (auto * node = typeid_cast(ast.get())) - { - /// For GLOBAL JOIN. - if (do_global && node->table_join - && static_cast(*node->table_join).locality == ASTTableJoin::Locality::Global) - addExternalStorage(node->table_expression); - } -} - - -void ExpressionAnalyzer::findExternalTables(ASTPtr & ast) -{ - /// Traverse from the bottom. Intentionally go into subqueries. - for (auto & child : ast->children) - findExternalTables(child); - - /// If table type identifier - StoragePtr external_storage; - - if (auto * node = typeid_cast(ast.get())) - if (node->kind == ASTIdentifier::Table) - if ((external_storage = context.tryGetExternalTable(node->name))) - external_tables[node->name] = external_storage; -} - - static std::pair getDatabaseAndTableNameFromIdentifier(const ASTIdentifier & identifier) { std::pair res; @@ -650,7 +586,6 @@ static std::pair getDatabaseAndTableNameFromIdentifier(const AST return res; } - static std::shared_ptr interpretSubquery( const ASTPtr & subquery_or_table_name, const Context & context, @@ -756,102 +691,6 @@ static std::shared_ptr interpretSubquery( subquery_depth + 1); } - -void ExpressionAnalyzer::addExternalStorage(ASTPtr & subquery_or_table_name_or_table_expression) -{ - /// With nondistributed queries, creating temporary tables does not make sense. 
- if (!(storage && storage->isRemote())) - return; - - ASTPtr subquery; - ASTPtr table_name; - ASTPtr subquery_or_table_name; - - if (typeid_cast(subquery_or_table_name_or_table_expression.get())) - { - table_name = subquery_or_table_name_or_table_expression; - subquery_or_table_name = table_name; - } - else if ( - const auto * ast_table_expr - = typeid_cast(subquery_or_table_name_or_table_expression.get())) - { - if (ast_table_expr->database_and_table_name) - { - table_name = ast_table_expr->database_and_table_name; - subquery_or_table_name = table_name; - } - else if (ast_table_expr->subquery) - { - subquery = ast_table_expr->subquery; - subquery_or_table_name = subquery; - } - } - else if (typeid_cast(subquery_or_table_name_or_table_expression.get())) - { - subquery = subquery_or_table_name_or_table_expression; - subquery_or_table_name = subquery; - } - - if (!subquery_or_table_name) - throw Exception( - "Logical error: unknown AST element passed to ExpressionAnalyzer::addExternalStorage method", - ErrorCodes::LOGICAL_ERROR); - - if (table_name) - { - /// If this is already an external table, you do not need to add anything. Just remember its presence. - if (external_tables.end() != external_tables.find(static_cast(*table_name).name)) - return; - } - - /// Generate the name for the external table. - String external_table_name = "_data" + toString(external_table_id); - while (external_tables.count(external_table_name)) - { - ++external_table_id; - external_table_name = "_data" + toString(external_table_id); - } - - auto interpreter = interpretSubquery(subquery_or_table_name, context, subquery_depth, {}); - - Block sample = interpreter->getSampleBlock(); - NamesAndTypesList columns = sample.getNamesAndTypesList(); - - StoragePtr external_storage = StorageMemory::create(external_table_name, ColumnsDescription{columns}); - external_storage->startup(); - - /** We replace the subquery with the name of the temporary table. 
- * It is in this form, the request will go to the remote server. - * This temporary table will go to the remote server, and on its side, - * instead of doing a subquery, you just need to read it. - */ - - auto database_and_table_name = std::make_shared(external_table_name, ASTIdentifier::Table); - - if (auto * ast_table_expr = typeid_cast(subquery_or_table_name_or_table_expression.get())) - { - ast_table_expr->subquery.reset(); - ast_table_expr->database_and_table_name = database_and_table_name; - - ast_table_expr->children.clear(); - ast_table_expr->children.emplace_back(database_and_table_name); - } - else - subquery_or_table_name_or_table_expression = database_and_table_name; - - external_tables[external_table_name] = external_storage; - subqueries_for_sets[external_table_name].source = interpreter->execute().in; - subqueries_for_sets[external_table_name].table = external_storage; - - /** NOTE If it was written IN tmp_table - the existing temporary (but not external) table, - * then a new temporary table will be created (for example, _data1), - * and the data will then be copied to it. - * Maybe this can be avoided. - */ -} - - static NamesAndTypesList::iterator findColumn(const String & name, NamesAndTypesList & cols) { return std::find_if(cols.begin(), cols.end(), [&](const NamesAndTypesList::value_type & val) { @@ -1127,8 +966,6 @@ void ExpressionAnalyzer::normalizeTreeImpl( /// If the WHERE clause or HAVING consists of a single alias, the reference must be replaced not only in children, but also in where_expression and having_expression. 
if (auto * select = typeid_cast(ast.get())) { - if (select->prewhere_expression) - normalizeTreeImpl(select->prewhere_expression, finished_asts, current_asts, current_alias, level + 1); if (select->where_expression) normalizeTreeImpl(select->where_expression, finished_asts, current_asts, current_alias, level + 1); if (select->having_expression) @@ -1445,19 +1282,6 @@ void ExpressionAnalyzer::optimizeLimitBy() elems = unique_elems; } - -void ExpressionAnalyzer::makeSetsForIndex() -{ - if (storage && select_query && storage->supportsIndexForIn()) - { - if (select_query->where_expression) - makeSetsForIndexImpl(select_query->where_expression, storage->getSampleBlock()); - if (select_query->prewhere_expression) - makeSetsForIndexImpl(select_query->prewhere_expression, storage->getSampleBlock()); - } -} - - void ExpressionAnalyzer::tryMakeSetFromSubquery(const ASTPtr & subquery_or_table_name) { BlockIO res = interpretSubquery(subquery_or_table_name, context, subquery_depth + 1, {})->execute(); @@ -1476,56 +1300,6 @@ void ExpressionAnalyzer::tryMakeSetFromSubquery(const ASTPtr & subquery_or_table prepared_sets[subquery_or_table_name.get()] = std::move(set); } - -void ExpressionAnalyzer::makeSetsForIndexImpl(const ASTPtr & node, const Block & sample_block) -{ - for (auto & child : node->children) - { - /// Don't descent into subqueries. - if (typeid_cast(child.get())) - continue; - - /// Don't dive into lambda functions - const auto * func = typeid_cast(child.get()); - if (func && func->name == "lambda") - continue; - - makeSetsForIndexImpl(child, sample_block); - } - - const auto * func = typeid_cast(node.get()); - if (func && functionIsInOperator(func->name)) - { - const IAST & args = *func->arguments; - - if (storage && storage->mayBenefitFromIndexForIn(args.children.at(0))) - { - const ASTPtr & arg = args.children.at(1); - - if (!prepared_sets.count(arg.get())) /// Not already prepared. 
- { - if (typeid_cast(arg.get()) || typeid_cast(arg.get())) - { - if (settings.use_index_for_in_with_subqueries) - tryMakeSetFromSubquery(arg); - } - else - { - NamesAndTypesList temp_columns = source_columns; - temp_columns.insert(temp_columns.end(), columns_added_by_join.begin(), columns_added_by_join.end()); - ExpressionActionsPtr temp_actions = std::make_shared(temp_columns); - getRootActions(func->arguments->children.at(0), true, false, temp_actions); - - Block sample_block_with_calculated_columns = temp_actions->getSampleBlock(); - if (sample_block_with_calculated_columns.has(args.children.at(0)->getColumnName())) - makeExplicitSet(func, sample_block_with_calculated_columns, true); - } - } - } - } -} - - void ExpressionAnalyzer::makeSet(const ASTFunction * node, const Block & sample_block) { /** You need to convert the right argument to a set. @@ -1577,47 +1351,6 @@ void ExpressionAnalyzer::makeSet(const ASTFunction * node, const Block & sample_ SetPtr set = std::make_shared( SizeLimits(settings.max_rows_in_set, settings.max_bytes_in_set, settings.set_overflow_mode)); - /** The following happens for GLOBAL INs: - * - in the addExternalStorage function, the IN (SELECT ...) subquery is replaced with IN _data1, - * in the subquery_for_set object, this subquery is set as source and the temporary table _data1 as the table. - * - this function shows the expression IN_data1. - */ - if (!subquery_for_set.source && (!storage || !storage->isRemote())) - { - auto interpreter = interpretSubquery(arg, context, subquery_depth, {}); - subquery_for_set.source - = std::make_shared(interpreter->getSampleBlock(), [interpreter]() mutable { - return interpreter->execute().in; - }); - - /** Why is LazyBlockInputStream used? - * - * The fact is that when processing a query of the form - * SELECT ... 
FROM remote_test WHERE column GLOBAL IN (subquery), - * if the distributed remote_test table contains localhost as one of the servers, - * the query will be interpreted locally again (and not sent over TCP, as in the case of a remote server). - * - * The query execution pipeline will be: - * CreatingSets - * subquery execution, filling the temporary table with _data1 (1) - * CreatingSets - * reading from the table _data1, creating the set (2) - * read from the table subordinate to remote_test. - * - * (The second part of the pipeline under CreateSets is a reinterpretation of the query inside StorageDistributed, - * the query differs in that the database name and tables are replaced with subordinates, and the subquery is replaced with _data1.) - * - * But when creating the pipeline, when creating the source (2), it will be found that the _data1 table is empty - * (because the query has not started yet), and empty source will be returned as the source. - * And then, when the query is executed, an empty set will be created in step (2). - * - * Therefore, we make the initialization of step (2) lazy - * - so that it does not occur until step (1) is completed, on which the table will be populated. - * - * Note: this solution is not very good, you need to think better. - */ - } - subquery_for_set.set = set; prepared_sets[arg.get()] = set; } @@ -2096,11 +1829,10 @@ void ExpressionAnalyzer::getActionsImpl( void ExpressionAnalyzer::getAggregates(const ASTPtr & ast, ExpressionActionsPtr & actions) { - /// There can not be aggregate functions inside the WHERE and PREWHERE. - if (select_query - && (ast.get() == select_query->where_expression.get() || ast.get() == select_query->prewhere_expression.get())) + /// There can not be aggregate functions inside the WHERE. 
+ if (select_query && ast.get() == select_query->where_expression.get()) { - assertNoAggregates(ast, "in WHERE or PREWHERE"); + assertNoAggregates(ast, "in WHERE"); return; } diff --git a/dbms/src/Interpreters/ExpressionAnalyzer.h b/dbms/src/Interpreters/ExpressionAnalyzer.h index eddea6ee591..c990d8114b5 100644 --- a/dbms/src/Interpreters/ExpressionAnalyzer.h +++ b/dbms/src/Interpreters/ExpressionAnalyzer.h @@ -113,13 +113,6 @@ class ExpressionAnalyzer : private boost::noncopyable PreparedSets getPreparedSets() { return prepared_sets; } - /** Tables that will need to be sent to remote servers for distributed query processing. - */ - const Tables & getExternalTables() const { return external_tables; } - - /// Create Set-s that we can from IN section to use the index on them. - void makeSetsForIndex(); - private: ASTPtr ast; ASTSelectQuery * select_query; @@ -146,9 +139,6 @@ class ExpressionAnalyzer : private boost::noncopyable NamesAndTypesList aggregation_keys; AggregateDescriptions aggregate_descriptions; - /// Do I need to prepare for execution global subqueries when analyzing the query. - bool do_global; - SubqueriesForSets subqueries_for_sets; PreparedSets prepared_sets; @@ -173,10 +163,6 @@ class ExpressionAnalyzer : private boost::noncopyable using SetOfASTs = std::set; using MapOfASTs = std::map; - /// All new temporary tables obtained by performing the GLOBAL IN/JOIN subqueries. - Tables external_tables; - size_t external_table_id = 1; - /** Remove all unnecessary columns from the list of all available columns of the table (`columns`). * At the same time, form a set of unknown columns (`unknown_required_source_columns`), * as well as the columns added by JOIN (`columns_added_by_join`). @@ -224,18 +210,6 @@ class ExpressionAnalyzer : private boost::noncopyable void executeScalarSubqueries(); void executeScalarSubqueriesImpl(ASTPtr & ast); - /// Find global subqueries in the GLOBAL IN/JOIN sections. Fills in external_tables. 
- void initGlobalSubqueriesAndExternalTables(); - void initGlobalSubqueries(ASTPtr & ast); - - /// Finds in the query the usage of external tables (as table identifiers). Fills in external_tables. - void findExternalTables(ASTPtr & ast); - - /** Initialize InterpreterSelectQuery for a subquery in the GLOBAL IN/JOIN section, - * create a temporary table of type Memory and store it in the external_tables dictionary. - */ - void addExternalStorage(ASTPtr & subquery_or_table_name); - void addJoinAction(ExpressionActionsPtr & actions, bool only_types) const; struct ScopeStack; @@ -279,13 +253,11 @@ class ExpressionAnalyzer : private boost::noncopyable void makeExplicitSet(const ASTFunction * node, const Block & sample_block, bool create_ordered_set); /** - * Create Set from a subuqery or a table expression in the query. The created set is suitable for using the index. + * Create Set from a subquery or a table expression in the query. The created set is suitable for using the index. * The set will not be created if its size hits the limit. */ void tryMakeSetFromSubquery(const ASTPtr & subquery_or_table_name); - void makeSetsForIndexImpl(const ASTPtr & node, const Block & sample_block); - /** Translate qualified names such as db.table.column, table.column, table_alias.column * to unqualified names. This is done in a poor transitional way: * only one ("main") table is supported. Ambiguity is not detected or resolved. diff --git a/dbms/src/Interpreters/InterpreterDropQuery.cpp b/dbms/src/Interpreters/InterpreterDropQuery.cpp index 259a97c10d5..07f8cb5dada 100644 --- a/dbms/src/Interpreters/InterpreterDropQuery.cpp +++ b/dbms/src/Interpreters/InterpreterDropQuery.cpp @@ -68,29 +68,6 @@ BlockIO InterpreterDropQuery::execute() return {}; } - /// Drop temporary table. - if (drop.database.empty() || drop.temporary) - { - StoragePtr table - = (context.hasSessionContext() ? 
context.getSessionContext() : context).tryRemoveExternalTable(drop.table); - if (table) - { - if (drop.database.empty() && !drop.temporary) - { - LOG_WARNING( - (&Poco::Logger::get("InterpreterDropQuery")), - "It is recommended to use `DROP TEMPORARY TABLE` to delete temporary tables"); - } - table->shutdown(); - /// If table was already dropped by anyone, an exception will be thrown - auto table_lock = table->lockExclusively(context.getCurrentQueryId(), drop.lock_timeout); - /// Delete table data - table->drop(); - table->is_dropped = true; - return {}; - } - } - String database_name = drop.database.empty() ? current_database : drop.database; String database_name_escaped = escapeForFileName(database_name); @@ -223,13 +200,10 @@ BlockIO InterpreterDropQuery::execute() } -void InterpreterDropQuery::checkAccess(const ASTDropQuery & drop) +void InterpreterDropQuery::checkAccess(const ASTDropQuery & /*drop*/) { const Settings & settings = context.getSettingsRef(); - auto readonly = settings.readonly; - - /// It's allowed to drop temporary tables. - if (!readonly || (drop.database.empty() && context.tryGetExternalTable(drop.table) && readonly >= 2)) + if (!settings.readonly) { return; } diff --git a/dbms/src/Interpreters/InterpreterExistsQuery.cpp b/dbms/src/Interpreters/InterpreterExistsQuery.cpp index d60f2086c38..c689fafad44 100644 --- a/dbms/src/Interpreters/InterpreterExistsQuery.cpp +++ b/dbms/src/Interpreters/InterpreterExistsQuery.cpp @@ -43,7 +43,8 @@ Block InterpreterExistsQuery::getSampleBlock() BlockInputStreamPtr InterpreterExistsQuery::executeImpl() { const ASTExistsQuery & ast = typeid_cast(*query_ptr); - bool res = ast.temporary ? 
context.isExternalTableExist(ast.table) : context.isTableExist(ast.database, ast.table); + RUNTIME_CHECK_MSG(!ast.temporary, "external table is not supported"); + bool res = context.isTableExist(ast.database, ast.table); return std::make_shared( Block{{ColumnUInt8::create(1, res), std::make_shared(), "result"}}); diff --git a/dbms/src/Interpreters/InterpreterInsertQuery.cpp b/dbms/src/Interpreters/InterpreterInsertQuery.cpp index 8a3b4253715..492e1dab93e 100644 --- a/dbms/src/Interpreters/InterpreterInsertQuery.cpp +++ b/dbms/src/Interpreters/InterpreterInsertQuery.cpp @@ -17,7 +17,6 @@ #include #include #include -#include #include #include #include @@ -28,6 +27,7 @@ #include #include #include +#include #include #include @@ -118,16 +118,9 @@ BlockIO InterpreterInsertQuery::execute() /// We create a pipeline of several streams, into which we will write data. BlockOutputStreamPtr out; - out = std::make_shared( - query.database, - query.table, + out = std::make_shared( table, - context, query_ptr, - query.no_destination); - - out = std::make_shared( - out, getSampleBlock(query, table), required_columns, table->getColumns().defaults, @@ -183,12 +176,10 @@ BlockIO InterpreterInsertQuery::execute() } -void InterpreterInsertQuery::checkAccess(const ASTInsertQuery & query) +void InterpreterInsertQuery::checkAccess(const ASTInsertQuery &) { const Settings & settings = context.getSettingsRef(); - auto readonly = settings.readonly; - - if (!readonly || (query.database.empty() && context.tryGetExternalTable(query.table) && readonly >= 2)) + if (!settings.readonly) { return; } diff --git a/dbms/src/Interpreters/InterpreterSelectQuery.cpp b/dbms/src/Interpreters/InterpreterSelectQuery.cpp index 94f7c880ad2..f9c730afb2a 100644 --- a/dbms/src/Interpreters/InterpreterSelectQuery.cpp +++ b/dbms/src/Interpreters/InterpreterSelectQuery.cpp @@ -86,7 +86,6 @@ extern const int TOO_DEEP_SUBQUERIES; extern const int THERE_IS_NO_COLUMN; extern const int SAMPLING_NOT_SUPPORTED; extern const 
int ILLEGAL_FINAL; -extern const int ILLEGAL_PREWHERE; extern const int TOO_MANY_COLUMNS; extern const int LOGICAL_ERROR; extern const int NOT_IMPLEMENTED; @@ -204,17 +203,6 @@ void InterpreterSelectQuery::init(const Names & required_result_column_names) throw Exception( (!input && storage) ? "Storage " + storage->getName() + " doesn't support FINAL" : "Illegal FINAL", ErrorCodes::ILLEGAL_FINAL); - - if (query.prewhere_expression && (input || !storage || !storage->supportsPrewhere())) - throw Exception( - (!input && storage) ? "Storage " + storage->getName() + " doesn't support PREWHERE" - : "Illegal PREWHERE", - ErrorCodes::ILLEGAL_PREWHERE); - - /// Save the new temporary tables in the query context - for (const auto & it : query_analyzer->getExternalTables()) - if (!context.tryGetExternalTable(it.first)) - context.addExternalTable(it.first, it.second); } } @@ -744,27 +732,15 @@ QueryProcessingStage::Enum InterpreterSelectQuery::executeFetchColumns(Pipeline size_t limit_offset = 0; getLimitLengthAndOffset(query, limit_length, limit_offset); - /** With distributed query processing, almost no computations are done in the threads, - * but wait and receive data from remote servers. - * If we have 20 remote servers, and max_threads = 8, then it would not be very good - * connect and ask only 8 servers at a time. - * To simultaneously query more remote servers, - * instead of max_threads, max_distributed_connections is used. - */ - if (storage && storage->isRemote()) - { - max_streams = settings.max_distributed_connections; - } - size_t max_block_size = settings.max_block_size; /** Optimization - if not specified DISTINCT, WHERE, GROUP, HAVING, ORDER, LIMIT BY but LIMIT is specified, and limit + offset < max_block_size, * then as the block size we will use limit + offset (not to read more from the table than requested), * and also set the number of threads to 1. 
*/ - if (!query.distinct && !query.prewhere_expression && !query.where_expression && !query.group_expression_list - && !query.having_expression && !query.order_expression_list && !query.limit_by_expression_list - && query.limit_length && !query_analyzer->hasAggregation() && limit_length + limit_offset < max_block_size) + if (!query.distinct && !query.where_expression && !query.group_expression_list && !query.having_expression + && !query.order_expression_list && !query.limit_by_expression_list && query.limit_length + && !query_analyzer->hasAggregation() && limit_length + limit_offset < max_block_size) { max_block_size = limit_length + limit_offset; max_streams = 1; @@ -794,8 +770,6 @@ QueryProcessingStage::Enum InterpreterSelectQuery::executeFetchColumns(Pipeline if (max_streams == 0) throw Exception("Logical error: zero number of streams requested", ErrorCodes::LOGICAL_ERROR); - query_analyzer->makeSetsForIndex(); - SelectQueryInfo query_info; query_info.query = query_ptr; query_info.sets = query_analyzer->getPreparedSets(); diff --git a/dbms/src/Interpreters/JoinV2/HashJoin.cpp b/dbms/src/Interpreters/JoinV2/HashJoin.cpp index e768e4edf7a..8c7cc817793 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoin.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoin.cpp @@ -13,9 +13,8 @@ // limitations under the License. 
#include -#include -#include #include +#include #include #include #include @@ -26,9 +25,15 @@ #include #include +#include namespace DB { +namespace FailPoints +{ +extern const char force_join_v2_probe_enable_lm[]; +extern const char force_join_v2_probe_disable_lm[]; +} // namespace FailPoints namespace { @@ -98,15 +103,6 @@ StringCollatorKind getStringCollatorKind(const TiDB::TiDBCollators & collators) } } -IColumn::Offsets initBaseOffsets(bool has_other_condition, size_t max_block_size) -{ - if (!has_other_condition) - return {}; - IColumn::Offsets offset(max_block_size); - std::iota(offset.begin(), offset.end(), 0ULL); - return offset; -} - } // namespace HashJoin::HashJoin( @@ -130,7 +126,6 @@ HashJoin::HashJoin( , log(Logger::get(join_req_id)) , has_other_condition(non_equal_conditions.other_cond_expr != nullptr) , output_columns(output_columns_) - , base_offsets(initBaseOffsets(has_other_condition, settings.max_block_size)) { RUNTIME_ASSERT(key_names_left.size() == key_names_right.size()); output_block = Block(output_columns); @@ -221,7 +216,7 @@ void HashJoin::initRowLayoutAndHashJoinMethod() { raw_required_key_flag[i] = true; raw_required_key_index_set.insert(index); - row_layout.raw_required_key_column_indexes.push_back({index, key_columns[i].is_nullable}); + row_layout.raw_key_column_indexes.push_back({index, key_columns[i].is_nullable}); continue; } } @@ -241,7 +236,7 @@ void HashJoin::initRowLayoutAndHashJoinMethod() key_names_right.swap(new_key_names_right); } - row_layout.other_required_count_for_other_condition = 0; + row_layout.other_column_count_for_other_condition = 0; size_t columns = right_sample_block_pruned.columns(); BoolVec required_columns_flag(columns); for (size_t i = 0; i < columns; ++i) @@ -254,16 +249,16 @@ void HashJoin::initRowLayoutAndHashJoinMethod() auto & c = right_sample_block_pruned.safeGetByPosition(i); if (required_columns_names_set_for_other_condition.contains(c.name)) { - 
++row_layout.other_required_count_for_other_condition; + ++row_layout.other_column_count_for_other_condition; required_columns_flag[i] = true; if (c.column->valuesHaveFixedSize()) { row_layout.other_column_fixed_size += c.column->sizeOfValueIfFixed(); - row_layout.other_required_column_indexes.push_back({i, true}); + row_layout.other_column_indexes.push_back({i, true}); } else { - row_layout.other_required_column_indexes.push_back({i, false}); + row_layout.other_column_indexes.push_back({i, false}); } } } @@ -275,21 +270,23 @@ void HashJoin::initRowLayoutAndHashJoinMethod() if (c.column->valuesHaveFixedSize()) { row_layout.other_column_fixed_size += c.column->sizeOfValueIfFixed(); - row_layout.other_required_column_indexes.push_back({i, true}); + row_layout.other_column_indexes.push_back({i, true}); } else { - row_layout.other_required_column_indexes.push_back({i, false}); + row_layout.other_column_indexes.push_back({i, false}); } + RUNTIME_CHECK_MSG( + output_block_after_finalize.has(c.name), + "output_block_after_finalize does not contain {}", + c.name); } - RUNTIME_CHECK( - row_layout.raw_required_key_column_indexes.size() + row_layout.other_required_column_indexes.size() == columns); + RUNTIME_CHECK(row_layout.raw_key_column_indexes.size() + row_layout.other_column_indexes.size() == columns); } void HashJoin::initBuild(const Block & sample_block, size_t build_concurrency_) { RUNTIME_CHECK_MSG(!build_initialized, "Logical error: Join build has been initialized"); - build_initialized = true; RUNTIME_CHECK_MSG(isFinalize(), "join should be finalized first"); right_sample_block = materializeBlock(sample_block); @@ -314,13 +311,14 @@ void HashJoin::initBuild(const Block & sample_block, size_t build_concurrency_) build_workers_data[i].key_getter = createHashJoinKeyGetter(method, collators); for (size_t i = 0; i < JOIN_BUILD_PARTITION_COUNT + 1; ++i) multi_row_containers.emplace_back(std::make_unique()); + + build_initialized = true; } void HashJoin::initProbe(const 
Block & sample_block, size_t probe_concurrency_) { RUNTIME_CHECK_MSG(build_initialized, "join build should be initialized first"); RUNTIME_CHECK_MSG(!probe_initialized, "Logical error: Join probe has been initialized"); - probe_initialized = true; RUNTIME_CHECK_MSG(isFinalize(), "join should be finalized first"); left_sample_block = materializeBlock(sample_block); @@ -348,6 +346,27 @@ void HashJoin::initProbe(const Block & sample_block, size_t probe_concurrency_) all_sample_block_pruned.insert(std::move(new_column)); } + + size_t all_columns = all_sample_block_pruned.columns(); + output_column_indexes.reserve(all_columns); + size_t output_columns = 0; + for (size_t i = 0; i < all_columns; ++i) + { + ssize_t output_index = -1; + const auto & name = all_sample_block_pruned.safeGetByPosition(i).name; + if (output_block_after_finalize.has(name)) + { + output_index = output_block_after_finalize.getPositionByName(name); + ++output_columns; + } + output_column_indexes.push_back(output_index); + } + RUNTIME_CHECK_MSG( + output_columns == output_block_after_finalize.columns(), + "output columns {} in all_sample_block_pruned != columns {} in output_block_after_finalize", + output_columns, + output_block_after_finalize.columns()); + if (has_other_condition) { left_required_flag_for_other_condition.resize(left_sample_block_pruned.columns()); @@ -367,6 +386,8 @@ void HashJoin::initProbe(const Block & sample_block, size_t probe_concurrency_) probe_concurrency = probe_concurrency_; active_probe_worker = probe_concurrency; probe_workers_data.resize(probe_concurrency); + + probe_initialized = true; } bool HashJoin::finishOneBuildRow(size_t stream_index) @@ -423,16 +444,39 @@ void HashJoin::workAfterBuildRowFinish() settings.probe_enable_prefetch_threshold, enable_tagged_pointer); + const size_t lm_size_threshold = 32; + bool late_materialization = false; + size_t avg_lm_row_size = 0; + if (has_other_condition + && row_layout.other_column_count_for_other_condition < 
row_layout.other_column_indexes.size()) + { + size_t total_lm_row_size = 0; + size_t total_lm_row_count = 0; + for (size_t i = 0; i < build_concurrency; ++i) + { + total_lm_row_size += build_workers_data[i].lm_row_size; + total_lm_row_count += build_workers_data[i].lm_row_count; + } + avg_lm_row_size = total_lm_row_count == 0 ? 0 : total_lm_row_size / total_lm_row_count; + late_materialization = avg_lm_row_size >= lm_size_threshold; + } + fiu_do_on(FailPoints::force_join_v2_probe_enable_lm, { late_materialization = true; }); + fiu_do_on(FailPoints::force_join_v2_probe_disable_lm, { late_materialization = false; }); + join_probe_helper = std::make_unique(this, late_materialization); + LOG_DEBUG( log, - "allocate pointer table cost {}ms, rows {}, pointer table size {}, added column num {}, enable prefetch {}, " - "enable tagged pointer {}", + "allocate pointer table and init join probe helper cost {}ms, rows {}, pointer table size {}, " + "added column num {}, enable prefetch {}, enable tagged pointer {}, " + "enable late materialization {}(avg size {})", watch.elapsedMilliseconds(), all_build_row_count, pointer_table.getPointerTableSize(), right_sample_block_pruned.columns(), pointer_table.enableProbePrefetch(), - pointer_table.enableTaggedPointer()); + pointer_table.enableTaggedPointer(), + late_materialization, + avg_lm_row_size); } void HashJoin::buildRowFromBlock(const Block & b, size_t stream_index) { @@ -452,8 +496,6 @@ void HashJoin::buildRowFromBlock(const Block & b, size_t stream_index) /// Note: this variable can't be removed because it will take smart pointers' lifecycle to the end of this function. Columns materialized_columns; ColumnRawPtrs key_columns = extractAndMaterializeKeyColumns(block, materialized_columns, key_names_right); - /// Some useless columns maybe key columns so they must be removed after extracting key columns. - removeUselessColumn(block); /// We will insert to the map only keys, where all components are not NULL.
ColumnPtr null_map_holder; @@ -462,6 +504,8 @@ void HashJoin::buildRowFromBlock(const Block & b, size_t stream_index) /// Reuse null_map to record the filtered rows, the rows contains NULL or does not /// match the join filter will not insert to the maps recordFilteredRows(block, non_equal_conditions.right_filter_column, null_map_holder, null_map); + /// Some useless columns may be key columns and filter column so they must be removed after extracting key columns and filter column. + removeUselessColumn(block); /// Rare case, when joined columns are constant. To avoid code bloat, simply materialize them. block = materializeBlock(block); @@ -476,6 +520,8 @@ assertBlocksHaveEqualStructure(block, right_sample_block_pruned, "Join Build"); + bool check_lm_row_size = has_other_condition + && row_layout.other_column_count_for_other_condition < row_layout.other_column_indexes.size(); insertBlockToRowContainers( method, needRecordNotInsertRows(kind),
if (has_other_condition) - { - late_materialization - = row_layout.other_required_count_for_other_condition < row_layout.other_required_column_indexes.size(); - } - - MutableColumns added_columns; - if (late_materialization) - { - for (auto [column_index, _] : row_layout.raw_required_key_column_indexes) - added_columns.emplace_back( - wd.result_block.safeGetByPosition(left_columns + column_index).column->assumeMutable()); - for (size_t i = 0; i < row_layout.other_required_count_for_other_condition; ++i) - { - size_t column_index = row_layout.other_required_column_indexes[i].first; - added_columns.emplace_back( - wd.result_block.safeGetByPosition(left_columns + column_index).column->assumeMutable()); - } - } - else - { - added_columns.resize(right_columns); - for (size_t i = 0; i < right_columns; ++i) - added_columns[i] = wd.result_block.safeGetByPosition(left_columns + i).column->assumeMutable(); - } - - Stopwatch watch; - joinProbeBlock( - context, - wd, - method, - kind, - late_materialization, - non_equal_conditions, - settings, - pointer_table, - row_layout, - added_columns, - wd.result_block.rows()); - wd.probe_hash_table_time += watch.elapsedFromLastTime(); - if (context.isCurrentProbeFinished()) + Block res = join_probe_helper->probe(context, wd); + if (context.isAllFinished()) wd.probe_handle_rows += context.rows; - - if (late_materialization) - { - size_t idx = 0; - for (auto [column_index, _] : row_layout.raw_required_key_column_indexes) - wd.result_block.safeGetByPosition(left_columns + column_index).column = std::move(added_columns[idx++]); - for (size_t i = 0; i < row_layout.other_required_count_for_other_condition; ++i) - { - size_t column_index = row_layout.other_required_column_indexes[i].first; - wd.result_block.safeGetByPosition(left_columns + column_index).column = std::move(added_columns[idx++]); - } - } - else - { - for (size_t i = 0; i < right_columns; ++i) - wd.result_block.safeGetByPosition(left_columns + i).column = 
std::move(added_columns[i]); - } - - if (wd.selective_offsets.empty()) - return output_block_after_finalize; - - if (has_other_condition) - { - // Always using late materialization for left side - for (size_t i = 0; i < left_columns; ++i) - { - if (!left_required_flag_for_other_condition[i]) - continue; - wd.result_block.safeGetByPosition(i).column->assumeMutable()->insertSelectiveFrom( - *context.block.safeGetByPosition(i).column.get(), - wd.selective_offsets); - } - } - else - { - for (size_t i = 0; i < left_columns; ++i) - { - wd.result_block.safeGetByPosition(i).column->assumeMutable()->insertSelectiveFrom( - *context.block.safeGetByPosition(i).column.get(), - wd.selective_offsets); - } - } - - wd.replicate_time += watch.elapsedFromLastTime(); - - if (has_other_condition) - { - auto res_block = handleOtherConditions(context, wd, late_materialization); - wd.other_condition_time += watch.elapsedFromLastTime(); - return res_block; - } - - if (wd.result_block.rows() >= settings.max_block_size) - { - auto res_block = removeUselessColumnForOutput(wd.result_block); - wd.result_block = {}; - return res_block; - } - return output_block_after_finalize; + return res; } Block HashJoin::probeLastResultBlock(size_t stream_index) @@ -701,14 +638,31 @@ void HashJoin::removeUselessColumn(Block & block) const Block HashJoin::removeUselessColumnForOutput(const Block & block) const { - // remove useless columns and adjust the order of columns - Block projected_block; - for (const auto & name_and_type : output_columns_after_finalize) + RUNTIME_CHECK(probe_initialized); + RUNTIME_CHECK(block.columns() == all_sample_block_pruned.columns()); + Block output_block = output_block_after_finalize.cloneEmpty(); + size_t columns = block.columns(); + for (size_t i = 0; i < columns; ++i) { - const auto & column = block.getByName(name_and_type.name); - projected_block.insert(std::move(column)); + if (output_column_indexes[i] == -1) + continue; + 
output_block.safeGetByPosition(output_column_indexes[i]) = block.safeGetByPosition(i); + } + return output_block; +} + +void HashJoin::initOutputBlock(Block & block) const +{ + if (!block) + { + size_t output_columns = output_block_after_finalize.columns(); + for (size_t i = 0; i < output_columns; ++i) + { + ColumnWithTypeAndName new_column = output_block_after_finalize.safeGetByPosition(i).cloneEmpty(); + new_column.column->assumeMutable()->reserveAlign(settings.max_block_size, FULL_VECTOR_SIZE_AVX2); + block.insert(std::move(new_column)); + } } - return projected_block; } void HashJoin::finalize(const Names & parent_require) @@ -826,243 +780,4 @@ void HashJoin::finalize(const Names & parent_require) finalized = true; } -Block HashJoin::handleOtherConditions(JoinProbeContext & context, JoinProbeWorkerData & wd, bool late_materialization) -{ - size_t left_columns = left_sample_block_pruned.columns(); - size_t right_columns = right_sample_block_pruned.columns(); - // Some columns in wd.result_block may be empty so need to create another block to execute other condition expressions - Block exec_block; - for (size_t i = 0; i < left_columns; ++i) - { - if (left_required_flag_for_other_condition[i]) - exec_block.insert(wd.result_block.getByPosition(i)); - } - if (late_materialization) - { - for (auto [column_index, _] : row_layout.raw_required_key_column_indexes) - exec_block.insert(wd.result_block.getByPosition(left_columns + column_index)); - for (size_t i = 0; i < row_layout.other_required_count_for_other_condition; ++i) - { - size_t column_index = row_layout.other_required_column_indexes[i].first; - exec_block.insert(wd.result_block.getByPosition(left_columns + column_index)); - } - } - else - { - for (size_t i = 0; i < right_columns; ++i) - exec_block.insert(wd.result_block.getByPosition(left_columns + i)); - } - - non_equal_conditions.other_cond_expr->execute(exec_block); - - size_t rows = exec_block.rows(); - RUNTIME_CHECK_MSG( - rows <= settings.max_block_size, 
- "exec_block rows {} > max_block_size {}", - rows, - settings.max_block_size); - - wd.filter.clear(); - mergeNullAndFilterResult(exec_block, wd.filter, non_equal_conditions.other_cond_name, false); - - size_t output_columns = output_block_after_finalize.columns(); - - auto init_result_block_for_other_condition = [&]() { - wd.result_block_for_other_condition = {}; - for (size_t i = 0; i < output_columns; ++i) - { - ColumnWithTypeAndName new_column = output_block_after_finalize.safeGetByPosition(i).cloneEmpty(); - new_column.column->assumeMutable()->reserveAlign(settings.max_block_size, FULL_VECTOR_SIZE_AVX2); - wd.result_block_for_other_condition.insert(std::move(new_column)); - } - }; - - if (!wd.result_block_for_other_condition) - init_result_block_for_other_condition(); - - RUNTIME_CHECK_MSG( - wd.result_block_for_other_condition.rows() < settings.max_block_size, - "result_block_for_other_condition rows {} >= max_block_size {}", - wd.result_block_for_other_condition.rows(), - settings.max_block_size); - size_t remaining_insert_size = settings.max_block_size - wd.result_block_for_other_condition.rows(); - size_t result_size = countBytesInFilter(wd.filter); - - bool filter_offsets_is_initialized = false; - auto init_filter_offsets = [&]() { - RUNTIME_CHECK(wd.filter.size() == rows); - wd.filter_offsets.clear(); - wd.filter_offsets.reserve(result_size); - filterImpl(&wd.filter[0], &wd.filter[rows], &base_offsets[0], wd.filter_offsets); - RUNTIME_CHECK(wd.filter_offsets.size() == result_size); - filter_offsets_is_initialized = true; - }; - - bool filter_selective_offsets_is_initialized = false; - auto init_filter_selective_offsets = [&]() { - RUNTIME_CHECK(wd.selective_offsets.size() == rows); - wd.filter_selective_offsets.clear(); - wd.filter_selective_offsets.reserve(result_size); - filterImpl(&wd.filter[0], &wd.filter[rows], &wd.selective_offsets[0], wd.filter_selective_offsets); - RUNTIME_CHECK(wd.filter_selective_offsets.size() == result_size); - 
filter_selective_offsets_is_initialized = true; - }; - - bool filter_row_ptrs_for_lm_is_initialized = false; - auto init_filter_row_ptrs_for_lm = [&]() { - RUNTIME_CHECK(wd.row_ptrs_for_lm.size() == rows); - wd.filter_row_ptrs_for_lm.clear(); - wd.filter_row_ptrs_for_lm.reserve(result_size); - filterImpl(&wd.filter[0], &wd.filter[rows], &wd.row_ptrs_for_lm[0], wd.filter_row_ptrs_for_lm); - RUNTIME_CHECK(wd.filter_row_ptrs_for_lm.size() == result_size); - filter_row_ptrs_for_lm_is_initialized = true; - }; - - auto fill_block = [&](size_t start, size_t length) { - if (late_materialization) - { - for (auto [column_index, _] : row_layout.raw_required_key_column_indexes) - { - const auto & name = right_sample_block_pruned.getByPosition(column_index).name; - if (!wd.result_block_for_other_condition.has(name)) - continue; - if unlikely (!filter_offsets_is_initialized) - init_filter_offsets(); - auto & des_column = wd.result_block_for_other_condition.getByName(name); - auto & src_column = exec_block.getByName(name); - des_column.column->assumeMutable() - ->insertSelectiveRangeFrom(*src_column.column.get(), wd.filter_offsets, start, length); - } - for (size_t i = 0; i < row_layout.other_required_count_for_other_condition; ++i) - { - size_t column_index = row_layout.other_required_column_indexes[i].first; - const auto & name = right_sample_block_pruned.getByPosition(column_index).name; - if (!wd.result_block_for_other_condition.has(name)) - continue; - if unlikely (!filter_offsets_is_initialized) - init_filter_offsets(); - auto & des_column = wd.result_block_for_other_condition.getByName(name); - auto & src_column = exec_block.getByName(name); - des_column.column->assumeMutable() - ->insertSelectiveRangeFrom(*src_column.column.get(), wd.filter_offsets, start, length); - } - if (!filter_row_ptrs_for_lm_is_initialized) - init_filter_row_ptrs_for_lm(); - - std::vector actual_column_indexes; - for (size_t i = row_layout.other_required_count_for_other_condition; - i < 
row_layout.other_required_column_indexes.size(); - ++i) - { - size_t column_index = row_layout.other_required_column_indexes[i].first; - const auto & name = right_sample_block_pruned.getByPosition(column_index).name; - size_t actual_column_index = wd.result_block_for_other_condition.getPositionByName(name); - actual_column_indexes.emplace_back(actual_column_index); - } - - constexpr size_t step = 256; - for (size_t i = start; i < start + length; i += step) - { - size_t end = i + step > start + length ? start + length : i + step; - wd.insert_batch.clear(); - wd.insert_batch.insert(&wd.row_ptrs_for_lm[i], &wd.row_ptrs_for_lm[end]); - for (auto column_index : actual_column_indexes) - { - auto & des_column = wd.result_block_for_other_condition.getByPosition(column_index); - des_column.column->assumeMutable()->deserializeAndInsertFromPos(wd.insert_batch, true); - } - } - for (auto column_index : actual_column_indexes) - { - auto & des_column = wd.result_block_for_other_condition.getByPosition(column_index); - des_column.column->assumeMutable()->flushNTAlignBuffer(); - } - } - else - { - for (size_t i = 0; i < right_columns; ++i) - { - const auto & name = right_sample_block_pruned.getByPosition(i).name; - if (!wd.result_block_for_other_condition.has(name)) - continue; - if unlikely (!filter_offsets_is_initialized) - init_filter_offsets(); - auto & des_column = wd.result_block_for_other_condition.getByName(name); - auto & src_column = exec_block.getByName(name); - des_column.column->assumeMutable() - ->insertSelectiveRangeFrom(*src_column.column.get(), wd.filter_offsets, start, length); - } - } - - for (size_t i = 0; i < left_columns; ++i) - { - const auto & name = left_sample_block_pruned.getByPosition(i).name; - if (!wd.result_block_for_other_condition.has(name)) - continue; - auto & des_column = wd.result_block_for_other_condition.getByName(name); - if (left_required_flag_for_other_condition[i]) - { - if unlikely (!filter_offsets_is_initialized && 
!filter_selective_offsets_is_initialized) - init_filter_selective_offsets(); - if (filter_offsets_is_initialized) - { - auto & src_column = exec_block.getByName(name); - des_column.column->assumeMutable() - ->insertSelectiveRangeFrom(*src_column.column.get(), wd.filter_offsets, start, length); - } - else - { - auto & src_column = context.block.safeGetByPosition(i); - des_column.column->assumeMutable()->insertSelectiveRangeFrom( - *src_column.column.get(), - wd.filter_selective_offsets, - start, - length); - } - continue; - } - if unlikely (!filter_selective_offsets_is_initialized) - init_filter_selective_offsets(); - auto & src_column = context.block.safeGetByPosition(i); - des_column.column->assumeMutable() - ->insertSelectiveRangeFrom(*src_column.column.get(), wd.filter_selective_offsets, start, length); - } - }; - - size_t length = result_size > remaining_insert_size ? remaining_insert_size : result_size; - fill_block(0, length); - - Block res_block; - if (result_size >= remaining_insert_size) - { - res_block = wd.result_block_for_other_condition; - init_result_block_for_other_condition(); - if (result_size > remaining_insert_size) - fill_block(remaining_insert_size, result_size - remaining_insert_size); - } - else - { - res_block = output_block_after_finalize; - } - - exec_block.clear(); - /// Remove the new column added from other condition expressions. - removeUselessColumn(wd.result_block); - - assertBlocksHaveEqualStructure( - wd.result_block, - all_sample_block_pruned, - "Join Probe reuses result_block for other condition"); - - /// Clear the data in result_block. 
- for (size_t i = 0; i < wd.result_block.columns(); ++i) - { - auto column = wd.result_block.getByPosition(i).column->assumeMutable(); - column->popBack(column->size()); - wd.result_block.getByPosition(i).column = std::move(column); - } - - return res_block; -} - } // namespace DB diff --git a/dbms/src/Interpreters/JoinV2/HashJoin.h b/dbms/src/Interpreters/JoinV2/HashJoin.h index 105e7d137b9..d00a4e9c661 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoin.h +++ b/dbms/src/Interpreters/JoinV2/HashJoin.h @@ -27,6 +27,8 @@ #include #include +#include + namespace DB { @@ -61,8 +63,10 @@ class HashJoin Block probeLastResultBlock(size_t stream_index); void removeUselessColumn(Block & block) const; + /// Block's schema must be all_sample_block_pruned. Block removeUselessColumnForOutput(const Block & block) const; + void initOutputBlock(Block & block) const; const Block & getOutputBlock() const { return finalized ? output_block_after_finalize : output_block; } const Names & getRequiredColumns() const { return required_columns; } void finalize(const Names & parent_require); @@ -78,9 +82,9 @@ class HashJoin void workAfterBuildRowFinish(); - Block handleOtherConditions(JoinProbeContext & context, JoinProbeWorkerData & wd, bool late_materialization); - private: + friend JoinProbeBlockHelper; + const ASTTableJoin::Kind kind; const String join_req_id; @@ -118,6 +122,10 @@ class HashJoin /// Block with columns from left-side and right-side table. Block all_sample_block_pruned; + /// Maps each column in all_sample_block_pruned to its index in output_block_after_finalize. + // -1 means the column is not in the output_block_after_finalize.
+ std::vector output_column_indexes; + NamesAndTypes output_columns; Block output_block; NamesAndTypes output_columns_after_finalize; @@ -131,22 +139,22 @@ class HashJoin /// Row containers std::vector> multi_row_containers; - /// Build phase + /// Build row phase size_t build_concurrency = 0; std::vector build_workers_data; std::atomic active_build_worker = 0; + HashJoinPointerTable pointer_table; + /// Probe phase size_t probe_concurrency = 0; std::vector probe_workers_data; std::atomic active_probe_worker = 0; - - HashJoinPointerTable pointer_table; + std::unique_ptr join_probe_helper; const JoinProfileInfoPtr profile_info = std::make_shared(); /// For other condition - const IColumn::Offsets base_offsets; BoolVec left_required_flag_for_other_condition; }; diff --git a/dbms/src/Interpreters/JoinV2/HashJoinBuild.cpp b/dbms/src/Interpreters/JoinV2/HashJoinBuild.cpp index 8cbac7c2a54..bfbb240314b 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinBuild.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinBuild.cpp @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include #include namespace DB @@ -30,7 +31,8 @@ void NO_INLINE insertBlockToRowContainersTypeImpl( ConstNullMapPtr null_map, const HashJoinRowLayout & row_layout, std::vector> & multi_row_containers, - JoinBuildWorkerData & wd) + JoinBuildWorkerData & wd, + bool check_lm_row_size) { using KeyGetterType = typename KeyGetter::Type; using Hash = typename KeyGetter::Hash; @@ -38,7 +40,7 @@ void NO_INLINE insertBlockToRowContainersTypeImpl( static_assert(sizeof(HashValueType) <= sizeof(decltype(wd.hashes)::value_type)); auto & key_getter = *static_cast(wd.key_getter.get()); - key_getter.reset(key_columns, row_layout.raw_required_key_column_indexes.size()); + key_getter.reset(key_columns, row_layout.raw_key_column_indexes.size()); wd.row_sizes.clear(); wd.row_sizes.resize_fill(rows, row_layout.other_column_fixed_size); @@ -46,13 +48,32 @@ void NO_INLINE insertBlockToRowContainersTypeImpl( /// The last partition is used to hold rows with null join key. constexpr size_t part_count = JOIN_BUILD_PARTITION_COUNT + 1; wd.partition_row_sizes.clear(); - wd.partition_row_sizes.resize_fill(part_count, 0); + wd.partition_row_sizes.resize_fill_zero(part_count); wd.partition_row_count.clear(); - wd.partition_row_count.resize_fill(part_count, 0); + wd.partition_row_count.resize_fill_zero(part_count); wd.partition_last_row_index.clear(); wd.partition_last_row_index.resize_fill(part_count, -1); - for (const auto & [index, is_fixed_size] : row_layout.other_required_column_indexes) + if (check_lm_row_size) + { + wd.lm_row_count += rows; + for (size_t i = row_layout.other_column_count_for_other_condition; i < row_layout.other_column_indexes.size(); + ++i) + { + size_t index = row_layout.other_column_indexes[i].first; + const auto & column = block.getByPosition(index).column; + if (const auto * column_string = typeid_cast(column.get())) + { + wd.lm_row_size += column_string->getChars().size() + sizeof(UInt32) * column_string->size(); + } + else + { + wd.lm_row_size += column->byteSize(); 
+ } + } + } + + for (const auto & [index, is_fixed_size] : row_layout.other_column_indexes) { if (!is_fixed_size) block.getByPosition(index).column->countSerializeByteSize(wd.row_sizes); @@ -129,7 +150,11 @@ void NO_INLINE insertBlockToRowContainersTypeImpl( container.hashes.reserve(wd.partition_row_count[i]); wd.partition_row_sizes[i] = 0; - wd.row_count += wd.partition_row_count[i]; + if (i != JOIN_BUILD_PARTITION_COUNT) + { + // Do not add the count of null rows + wd.row_count += wd.partition_row_count[i]; + } } } @@ -179,7 +204,7 @@ void NO_INLINE insertBlockToRowContainersTypeImpl( key_getter.serializeJoinKey(key, ptr); ptr += key_getter.getJoinKeyByteSize(key); } - for (const auto & [index, _] : row_layout.other_required_column_indexes) + for (const auto & [index, _] : row_layout.other_column_indexes) { if constexpr (has_null_map && !need_record_null_rows) block.getByPosition(index).column->serializeToPos(wd.row_ptrs, start, end - start, true); @@ -204,7 +229,8 @@ void insertBlockToRowContainersType( ConstNullMapPtr null_map, const HashJoinRowLayout & row_layout, std::vector> & multi_row_containers, - JoinBuildWorkerData & worker_data) + JoinBuildWorkerData & worker_data, + bool check_lm_row_size) { #define CALL(has_null_map, need_record_null_rows) \ insertBlockToRowContainersTypeImpl( \ @@ -214,7 +240,8 @@ void insertBlockToRowContainersType( null_map, \ row_layout, \ multi_row_containers, \ - worker_data); + worker_data, \ + check_lm_row_size); if (null_map) { @@ -243,7 +270,8 @@ void insertBlockToRowContainers( ConstNullMapPtr null_map, const HashJoinRowLayout & row_layout, std::vector> & multi_row_containers, - JoinBuildWorkerData & worker_data) + JoinBuildWorkerData & worker_data, + bool check_lm_row_size) { switch (method) { @@ -258,7 +286,8 @@ void insertBlockToRowContainers( null_map, \ row_layout, \ multi_row_containers, \ - worker_data); \ + worker_data, \ + check_lm_row_size); \ break; APPLY_FOR_HASH_JOIN_VARIANTS(M) #undef M diff --git 
a/dbms/src/Interpreters/JoinV2/HashJoinBuild.h b/dbms/src/Interpreters/JoinV2/HashJoinBuild.h index 43d4fa1040f..906ca0841a6 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinBuild.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinBuild.h @@ -43,6 +43,7 @@ inline size_t getJoinBuildPartitionNum(HashValueType hash) struct alignas(CPU_CACHE_LINE_SIZE) JoinBuildWorkerData { std::unique_ptr> key_getter; + /// Count of not-null rows size_t row_count = 0; RowPtr null_rows_list_head = nullptr; @@ -66,6 +67,10 @@ struct alignas(CPU_CACHE_LINE_SIZE) JoinBuildWorkerData size_t all_size = 0; bool enable_tagged_pointer = true; + + /// Used for checking if late materialization will be enabled. + size_t lm_row_size = 0; + size_t lm_row_count = 0; }; void insertBlockToRowContainers( @@ -77,7 +82,8 @@ void insertBlockToRowContainers( ConstNullMapPtr null_map, const HashJoinRowLayout & row_layout, std::vector> & multi_row_containers, - JoinBuildWorkerData & worker_data); + JoinBuildWorkerData & worker_data, + bool check_lm_row_size); } // namespace DB diff --git a/dbms/src/Interpreters/JoinV2/HashJoinKey.cpp b/dbms/src/Interpreters/JoinV2/HashJoinKey.cpp index ee880482502..0db00123a35 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinKey.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinKey.cpp @@ -103,7 +103,7 @@ void resetHashJoinKeyGetter( case HashJoinKeyMethod::METHOD: \ using KeyGetter##METHOD = typename HashJoinKeyGetterForType::Type; \ static_cast(key_getter.get()) \ - ->reset(key_columns, row_layout.raw_required_key_column_indexes.size()); \ + ->reset(key_columns, row_layout.raw_key_column_indexes.size()); \ break; APPLY_FOR_HASH_JOIN_VARIANTS(M) #undef M diff --git a/dbms/src/Interpreters/JoinV2/HashJoinPointerTable.cpp b/dbms/src/Interpreters/JoinV2/HashJoinPointerTable.cpp index 680701eabe1..de93fae6824 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinPointerTable.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinPointerTable.cpp @@ -135,7 +135,7 @@ bool 
HashJoinPointerTable::buildImpl( { UInt16 tag = (hash & ROW_PTR_TAG_MASK) | getRowPtrTag(old_head); pointer_table[bucket].fetch_or( - static_cast(tag) << (64 - ROW_PTR_TAG_BITS), + static_cast(tag) << ROW_PTR_TAG_SHIFT, std::memory_order_relaxed); old_head = removeRowPtrTag(old_head); } diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index 5f2f5277a45..fc5a4f010ed 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -13,15 +13,18 @@ // limitations under the License. #include +#include +#include #include #include #include #include +#include #include #include +#include -#include "Columns/IColumn.h" -#include "Parsers/ASTTablesInSelectQuery.h" +#include #ifdef TIFLASH_ENABLE_AVX_SUPPORT ASSERT_USE_AVX2_COMPILE_FLAG @@ -32,9 +35,18 @@ namespace DB using enum ASTTableJoin::Kind; -bool JoinProbeContext::isCurrentProbeFinished() const +bool JoinProbeContext::isProbeFinished() const { - return start_row_idx >= rows && prefetch_active_states == 0 && rows_is_matched.empty(); + return start_row_idx >= rows + // For prefetching + && prefetch_active_states == 0; +} + +bool JoinProbeContext::isAllFinished() const +{ + return isProbeFinished() + // For left outer with other conditions + && rows_not_matched.empty(); } void JoinProbeContext::resetBlock(Block & block_) @@ -45,7 +57,6 @@ void JoinProbeContext::resetBlock(Block & block_) start_row_idx = 0; current_row_ptr = nullptr; current_row_is_matched = false; - rows_is_matched.clear(); prefetch_active_states = 0; @@ -71,7 +82,12 @@ void JoinProbeContext::prepareForHashProbe( return; key_columns = extractAndMaterializeKeyColumns(block, materialized_columns, key_names); - /// Some useless columns maybe key columns so they must be removed after extracting key columns. + /// Keys with NULL value in any column won't join to anything. 
+ extractNestedColumnsAndNullMap(key_columns, null_map_holder, null_map); + /// reuse null_map to record the filtered rows, the rows contains NULL or does not + /// match the join filter won't join to anything + recordFilteredRows(block, filter_column, null_map_holder, null_map); + /// Some useless columns maybe key columns and filter column so they must be removed after extracting key columns and filter column. for (size_t pos = 0; pos < block.columns();) { if (!probe_output_name_set.contains(block.getByPosition(pos).name)) @@ -80,12 +96,6 @@ void JoinProbeContext::prepareForHashProbe( ++pos; } - /// Keys with NULL value in any column won't join to anything. - extractNestedColumnsAndNullMap(key_columns, null_map_holder, null_map); - /// reuse null_map to record the filtered rows, the rows contains NULL or does not - /// match the join filter won't join to anything - recordFilteredRows(block, filter_column, null_map_holder, null_map); - if unlikely (!key_getter) key_getter = createHashJoinKeyGetter(method, collators); resetHashJoinKeyGetter(method, key_getter, key_columns, row_layout); @@ -102,306 +112,384 @@ void JoinProbeContext::prepareForHashProbe( assertBlocksHaveEqualStructure(block, sample_block_pruned, "Join Probe"); - if ((kind == LeftOuter || kind == Semi || kind == Anti) && has_other_condition) + if (kind == LeftOuter && has_other_condition) { - rows_is_matched.clear(); - rows_is_matched.resize_fill_zero(block.rows()); + rows_not_matched.clear(); + rows_not_matched.resize_fill(block.rows(), 1); + not_matched_offsets_idx = -1; + not_matched_offsets.clear(); } is_prepared = true; } -#define PREFETCH_READ(ptr) __builtin_prefetch((ptr), 0 /* rw==read */, 3 /* locality */) - -/// The implemtation of prefetching in join probe process is inspired by a paper named -/// `Asynchronous Memory Access Chaining` in vldb-15. 
-/// Ref: https://www.vldb.org/pvldb/vol9/p252-kocberber.pdf -enum class ProbePrefetchStage : UInt8 +template +struct ProbeAdder { - None, - FindHeader, - FindNext, -}; + static constexpr bool need_not_matched = false; -template -struct ProbePrefetchState -{ - using KeyGetterType = typename KeyGetter::Type; - using KeyType = typename KeyGetterType::KeyType; - using HashValueType = typename KeyGetter::HashValueType; + static bool ALWAYS_INLINE addMatched( + JoinProbeBlockHelper & helper, + JoinProbeContext &, + JoinProbeWorkerData & wd, + MutableColumns & added_columns, + size_t idx, + size_t & current_offset, + RowPtr row_ptr, + size_t ptr_offset) + { + ++current_offset; + wd.selective_offsets.push_back(idx); + helper.insertRowToBatch(wd, added_columns, row_ptr + ptr_offset); + return current_offset >= helper.settings.max_block_size; + } - ProbePrefetchStage stage = ProbePrefetchStage::None; - bool is_matched = false; - UInt16 hash_tag = 0; - HashValueType hash = 0; - size_t index = 0; - union + static bool ALWAYS_INLINE + addNotMatched(JoinProbeBlockHelper &, JoinProbeContext &, JoinProbeWorkerData &, MutableColumns &, size_t, size_t &) { - RowPtr ptr = nullptr; - std::atomic * pointer_ptr; - }; - KeyType key{}; + return false; + } + + static void flush(JoinProbeBlockHelper & helper, JoinProbeWorkerData & wd, MutableColumns & added_columns) + { + helper.flushBatchIfNecessary(wd, added_columns); + helper.fillNullMapWithZero(added_columns); + } }; -template -class JoinProbeBlockHelper +template +struct ProbeAdder { -public: - using KeyGetterType = typename KeyGetter::Type; - using KeyType = typename KeyGetterType::KeyType; - using Hash = typename KeyGetter::Hash; - using HashValueType = typename KeyGetter::HashValueType; + static constexpr bool need_not_matched = !has_other_condition; - JoinProbeBlockHelper( - JoinProbeContext & context, + static bool ALWAYS_INLINE addMatched( + JoinProbeBlockHelper & helper, + JoinProbeContext &, JoinProbeWorkerData & wd, - 
HashJoinKeyMethod method, - ASTTableJoin::Kind kind, - const JoinNonEqualConditions & non_equal_conditions, - const HashJoinSettings & settings, - const HashJoinPointerTable & pointer_table, - const HashJoinRowLayout & row_layout, MutableColumns & added_columns, - size_t added_rows) - : context(context) - , wd(wd) - , method(method) - , kind(kind) - , non_equal_conditions(non_equal_conditions) - , settings(settings) - , pointer_table(pointer_table) - , row_layout(row_layout) - , added_columns(added_columns) - , added_rows(added_rows) - { - wd.insert_batch.clear(); - wd.insert_batch.reserve(settings.probe_insert_batch_size); - wd.selective_offsets.clear(); - wd.selective_offsets.reserve(settings.max_block_size); - if constexpr (late_materialization) - { - wd.row_ptrs_for_lm.clear(); - wd.row_ptrs_for_lm.reserve(settings.max_block_size); - } + size_t idx, + size_t & current_offset, + RowPtr row_ptr, + size_t ptr_offset) + { + ++current_offset; + wd.selective_offsets.push_back(idx); + helper.insertRowToBatch(wd, added_columns, row_ptr + ptr_offset); + return current_offset >= helper.settings.max_block_size; + } - if (pointer_table.enableProbePrefetch() && !context.prefetch_states) + static bool ALWAYS_INLINE addNotMatched( + JoinProbeBlockHelper & helper, + JoinProbeContext &, + JoinProbeWorkerData & wd, + MutableColumns &, + size_t idx, + size_t & current_offset) + { + if constexpr (!has_other_condition) { - context.prefetch_states = decltype(context.prefetch_states)( - static_cast(new ProbePrefetchState[settings.probe_prefetch_step]), - [](void * ptr) { delete[] static_cast *>(ptr); }); + ++current_offset; + wd.not_matched_selective_offsets.push_back(idx); + return current_offset >= helper.settings.max_block_size; } + return false; + } - if constexpr (late_materialization) - { - RUNTIME_CHECK( - added_columns.size() - == row_layout.raw_required_key_column_indexes.size() - + row_layout.other_required_count_for_other_condition); - } - else + static void 
flush(JoinProbeBlockHelper & helper, JoinProbeWorkerData & wd, MutableColumns & added_columns) + { + helper.flushBatchIfNecessary(wd, added_columns); + helper.fillNullMapWithZero(added_columns); + + if constexpr (!has_other_condition) { - RUNTIME_CHECK( - added_columns.size() - == row_layout.raw_required_key_column_indexes.size() + row_layout.other_required_column_indexes.size()); + if (!wd.not_matched_selective_offsets.empty()) + { + size_t null_size = wd.not_matched_selective_offsets.size(); + for (auto & column : added_columns) + column->insertManyDefaults(null_size); + wd.selective_offsets.insert( + wd.not_matched_selective_offsets.begin(), + wd.not_matched_selective_offsets.end()); + wd.not_matched_selective_offsets.clear(); + } } } +}; - void joinProbeBlockImpl(); +JoinProbeBlockHelper::JoinProbeBlockHelper(const HashJoin * join, bool late_materialization) + : join(join) + , settings(join->settings) + , pointer_table(join->pointer_table) + , row_layout(join->row_layout) +{ +#define CALL3(KeyGetter, JoinType, has_other_condition, late_materialization, tagged_pointer) \ + { \ + func_ptr_has_null \ + = &JoinProbeBlockHelper:: \ + probeImpl; \ + func_ptr_no_null \ + = &JoinProbeBlockHelper:: \ + probeImpl; \ + } -private: - void NO_INLINE joinProbeBlockInner(); - void NO_INLINE joinProbeBlockInnerPrefetch(); +#define CALL2(KeyGetter, JoinType, has_other_condition, late_materialization) \ + { \ + if (pointer_table.enableTaggedPointer()) \ + CALL3(KeyGetter, JoinType, has_other_condition, late_materialization, true) \ + else \ + CALL3(KeyGetter, JoinType, has_other_condition, late_materialization, false) \ + } - void NO_INLINE joinProbeBlockLeftOuter(); - void NO_INLINE joinProbeBlockLeftOuterPrefetch(); +#define CALL1(KeyGetter, JoinType) \ + { \ + if (join->has_other_condition) \ + { \ + if (late_materialization) \ + CALL2(KeyGetter, JoinType, true, true) \ + else \ + CALL2(KeyGetter, JoinType, true, false) \ + } \ + else \ + CALL2(KeyGetter, JoinType, false, 
false) \ + } - void NO_INLINE joinProbeBlockSemi(); - void NO_INLINE joinProbeBlockSemiPrefetch(); +#define CALL(KeyGetter) \ + { \ + auto kind = join->kind; \ + /*bool has_other_condition = join->has_other_condition;*/ \ + if (kind == Inner) \ + CALL1(KeyGetter, Inner) \ + else if (kind == LeftOuter) \ + CALL1(KeyGetter, LeftOuter) \ + /*else if (kind == Semi && !has_other_condition) \ + CALL2(KeyGetter, Semi, false, false) \ + else if (kind == Anti && !has_other_condition) \ + CALL2(KeyGetter, Anti, false, false)*/ \ + else \ + throw Exception( \ + fmt::format("Logical error: unknown combination of JOIN {}", magic_enum::enum_name(join->kind)), \ + ErrorCodes::LOGICAL_ERROR); \ + } - void NO_INLINE joinProbeBlockAnti(); - void NO_INLINE joinProbeBlockAntiPrefetch(); + switch (join->method) + { +#define M(METHOD) \ + case HashJoinKeyMethod::METHOD: \ + using KeyGetterType##METHOD = HashJoinKeyGetterForType; \ + CALL(KeyGetterType##METHOD); \ + break; + APPLY_FOR_HASH_JOIN_VARIANTS(M) +#undef M + + default: + throw Exception( + fmt::format("Unknown JOIN keys variant {}.", magic_enum::enum_name(join->method)), + ErrorCodes::UNKNOWN_SET_DATA_VARIANT); + } +#undef CALL +#undef CALL1 +#undef CALL2 +#undef CALL3 +} - void NO_INLINE joinProbeBlockRightOuter(); - void NO_INLINE joinProbeBlockRightOuterPrefetch(); +Block JoinProbeBlockHelper::probe(JoinProbeContext & context, JoinProbeWorkerData & wd) +{ + if (context.null_map) + return (this->*func_ptr_has_null)(context, wd); + else + return (this->*func_ptr_no_null)(context, wd); +} - void NO_INLINE joinProbeBlockRightSemi(); - void NO_INLINE joinProbeBlockRightSemiPrefetch(); +JOIN_PROBE_TEMPLATE +Block JoinProbeBlockHelper::probeImpl(JoinProbeContext & context, JoinProbeWorkerData & wd) +{ + static_assert(has_other_condition || !late_materialization); - void NO_INLINE joinProbeBlockRightAnti(); - void NO_INLINE joinProbeBlockRightAntiPrefetch(); + if unlikely (context.rows == 0) + return 
join->output_block_after_finalize; - bool ALWAYS_INLINE joinKeyIsEqual( - KeyGetterType & key_getter, - const KeyType & key1, - const KeyType & key2, - HashValueType hash1, - RowPtr row_ptr) const + if constexpr (kind == LeftOuter && has_other_condition) { - if constexpr (KeyGetterType::joinKeyCompareHashFirst()) - { - auto hash2 = unalignedLoad(row_ptr + sizeof(RowPtr)); - if (hash1 != hash2) - return false; - } - return key_getter.joinKeyIsEqual(key1, key2); + if (context.isProbeFinished()) + return fillNotMatchedRowsForLeftOuter(context, wd); } - void ALWAYS_INLINE insertRowToBatch(KeyGetterType & key_getter, RowPtr row_ptr, const KeyType & key) const + wd.insert_batch.clear(); + wd.insert_batch.reserve(settings.probe_insert_batch_size); + wd.selective_offsets.clear(); + wd.selective_offsets.reserve(settings.max_block_size); + if constexpr (kind == LeftOuter && !has_other_condition) { - wd.insert_batch.push_back(row_ptr + key_getter.getRequiredKeyOffset(key)); - flushBatchIfNecessary(); + wd.not_matched_selective_offsets.clear(); + wd.not_matched_selective_offsets.reserve(settings.max_block_size); + } + if constexpr (late_materialization) + { + wd.row_ptrs_for_lm.clear(); + wd.row_ptrs_for_lm.reserve(settings.max_block_size); } - template - void ALWAYS_INLINE flushBatchIfNecessary() const + size_t left_columns = join->left_sample_block_pruned.columns(); + size_t right_columns = join->right_sample_block_pruned.columns(); + if (!wd.result_block) { - if constexpr (!force) + RUNTIME_CHECK(left_columns + right_columns == join->all_sample_block_pruned.columns()); + for (size_t i = 0; i < left_columns + right_columns; ++i) { - if likely (wd.insert_batch.size() < settings.probe_insert_batch_size) - return; + ColumnWithTypeAndName new_column = join->all_sample_block_pruned.safeGetByPosition(i).cloneEmpty(); + new_column.column->assumeMutable()->reserveAlign(settings.max_block_size, FULL_VECTOR_SIZE_AVX2); + wd.result_block.insert(std::move(new_column)); } - if constexpr 
(late_materialization) - { - size_t idx = 0; - for (auto [_, is_nullable] : row_layout.raw_required_key_column_indexes) - { - IColumn * column = added_columns[idx].get(); - if (has_null_map && is_nullable) - column = &static_cast(*added_columns[idx]).getNestedColumn(); - column->deserializeAndInsertFromPos(wd.insert_batch, true); - ++idx; - } - for (size_t i = 0; i < row_layout.other_required_count_for_other_condition; ++i) - added_columns[idx++]->deserializeAndInsertFromPos(wd.insert_batch, true); + } - wd.row_ptrs_for_lm.insert(wd.insert_batch.begin(), wd.insert_batch.end()); - } - else + MutableColumns added_columns; + if constexpr (late_materialization) + { + for (auto [column_index, _] : row_layout.raw_key_column_indexes) + added_columns.emplace_back( + wd.result_block.safeGetByPosition(left_columns + column_index).column->assumeMutable()); + for (size_t i = 0; i < row_layout.other_column_count_for_other_condition; ++i) { - for (auto [column_index, is_nullable] : row_layout.raw_required_key_column_indexes) - { - IColumn * column = added_columns[column_index].get(); - if (has_null_map && is_nullable) - column = &static_cast(*added_columns[column_index]).getNestedColumn(); - column->deserializeAndInsertFromPos(wd.insert_batch, true); - } - for (auto [column_index, _] : row_layout.other_required_column_indexes) - added_columns[column_index]->deserializeAndInsertFromPos(wd.insert_batch, true); + size_t column_index = row_layout.other_column_indexes[i].first; + added_columns.emplace_back( + wd.result_block.safeGetByPosition(left_columns + column_index).column->assumeMutable()); } + } + else + { + added_columns.resize(right_columns); + for (size_t i = 0; i < right_columns; ++i) + added_columns[i] = wd.result_block.safeGetByPosition(left_columns + i).column->assumeMutable(); + } - if constexpr (force) + Stopwatch watch; + if (pointer_table.enableProbePrefetch()) + probeFillColumnsPrefetch< + KeyGetter, + kind, + has_null_map, + has_other_condition, + 
late_materialization, + tagged_pointer>(context, wd, added_columns); + else + probeFillColumns( + context, + wd, + added_columns); + + wd.probe_hash_table_time += watch.elapsedFromLastTime(); + + if constexpr (late_materialization) + { + size_t idx = 0; + for (auto [column_index, _] : row_layout.raw_key_column_indexes) + wd.result_block.safeGetByPosition(left_columns + column_index).column = std::move(added_columns[idx++]); + for (size_t i = 0; i < row_layout.other_column_count_for_other_condition; ++i) { - if constexpr (late_materialization) - { - size_t idx = 0; - for (auto [_, is_nullable] : row_layout.raw_required_key_column_indexes) - { - IColumn * column = added_columns[idx].get(); - if (has_null_map && is_nullable) - column = &static_cast(*added_columns[idx]).getNestedColumn(); - column->flushNTAlignBuffer(); - ++idx; - } - for (size_t i = 0; i < row_layout.other_required_count_for_other_condition; ++i) - added_columns[idx++]->flushNTAlignBuffer(); - } - else - { - for (auto [column_index, is_nullable] : row_layout.raw_required_key_column_indexes) - { - IColumn * column = added_columns[column_index].get(); - if (has_null_map && is_nullable) - column = &static_cast(*added_columns[column_index]).getNestedColumn(); - column->flushNTAlignBuffer(); - } - for (auto [column_index, _] : row_layout.other_required_column_indexes) - added_columns[column_index]->flushNTAlignBuffer(); - } + size_t column_index = row_layout.other_column_indexes[i].first; + wd.result_block.safeGetByPosition(left_columns + column_index).column = std::move(added_columns[idx++]); } - - wd.insert_batch.clear(); } + else + { + for (size_t i = 0; i < right_columns; ++i) + wd.result_block.safeGetByPosition(left_columns + i).column = std::move(added_columns[i]); + } + + if (wd.selective_offsets.empty()) + return join->output_block_after_finalize; - void ALWAYS_INLINE fillNullMapWithZero(size_t size) const + if constexpr (has_other_condition) { - if constexpr (has_null_map) + // Always using late 
materialization for left side + for (size_t i = 0; i < left_columns; ++i) { - for (auto [column_index, is_nullable] : row_layout.raw_required_key_column_indexes) - { - if (is_nullable) - { - auto & null_map_vec - = static_cast(*added_columns[column_index]).getNullMapColumn().getData(); - null_map_vec.resize_fill_zero(null_map_vec.size() + size); - } - } + if (!join->left_required_flag_for_other_condition[i]) + continue; + wd.result_block.safeGetByPosition(i).column->assumeMutable()->insertSelectiveFrom( + *context.block.safeGetByPosition(i).column.get(), + wd.selective_offsets); + } + } + else + { + for (size_t i = 0; i < left_columns; ++i) + { + wd.result_block.safeGetByPosition(i).column->assumeMutable()->insertSelectiveFrom( + *context.block.safeGetByPosition(i).column.get(), + wd.selective_offsets); } } -private: - JoinProbeContext & context; - JoinProbeWorkerData & wd; - const HashJoinKeyMethod method; - const ASTTableJoin::Kind kind; - const JoinNonEqualConditions & non_equal_conditions; - const HashJoinSettings & settings; - const HashJoinPointerTable & pointer_table; - const HashJoinRowLayout & row_layout; - MutableColumns & added_columns; - size_t added_rows; -}; + wd.replicate_time += watch.elapsedFromLastTime(); -template -void JoinProbeBlockHelper::joinProbeBlockImpl() -{ -#define CALL(JoinType) \ - { \ - if (pointer_table.enableProbePrefetch()) \ - joinProbeBlock##JoinType##Prefetch(); \ - else \ - joinProbeBlock##JoinType(); \ - } - - if (kind == Inner) - CALL(Inner) - else if (kind == LeftOuter) - CALL(LeftOuter) - else if (kind == Semi) - CALL(Semi) - else if (kind == Anti) - CALL(Anti) - else if (kind == RightOuter) - CALL(RightOuter) - else if (kind == RightSemi) - CALL(RightSemi) - else if (kind == RightAnti) - CALL(RightAnti) - else - throw Exception("Logical error: unknown combination of JOIN", ErrorCodes::LOGICAL_ERROR); + if constexpr (has_other_condition) + { + auto res_block = handleOtherConditions(context, wd, kind, late_materialization); 
+ wd.other_condition_time += watch.elapsedFromLastTime(); + return res_block; + } -#undef CALL2 -#undef CALL + if (wd.result_block.rows() >= settings.max_block_size) + { + auto res_block = join->removeUselessColumnForOutput(wd.result_block); + wd.result_block = {}; + return res_block; + } + return join->output_block_after_finalize; } -template -void NO_INLINE -JoinProbeBlockHelper::joinProbeBlockInner() +JOIN_PROBE_TEMPLATE +void JoinProbeBlockHelper::probeFillColumns( + JoinProbeContext & context, + JoinProbeWorkerData & wd, + MutableColumns & added_columns) { + using KeyGetterType = typename KeyGetter::Type; + using Hash = typename KeyGetter::Hash; + using HashValueType = typename KeyGetter::HashValueType; + constexpr bool has_null_key = has_null_map || kind == LeftOuter; + if constexpr (!has_null_key) + { + for (auto [_, is_nullable] : row_layout.raw_key_column_indexes) + RUNTIME_CHECK_MSG(!is_nullable, "has_null_key is false but a key column is nullable"); + } + using Adder = ProbeAdder; + auto & key_getter = *static_cast(context.key_getter.get()); - size_t current_offset = added_rows; - auto & selective_offsets = wd.selective_offsets; + size_t current_offset = wd.result_block.rows(); size_t idx = context.start_row_idx; RowPtr ptr = context.current_row_ptr; + bool is_matched = context.current_row_is_matched; size_t collision = 0; size_t key_offset = sizeof(RowPtr); if constexpr (KeyGetterType::joinKeyCompareHashFirst()) { key_offset += sizeof(HashValueType); } + +#define NOT_MATCHED(not_matched) \ + if constexpr (Adder::need_not_matched) \ + { \ + assert(ptr == nullptr); \ + if (not_matched) \ + { \ + bool is_end = Adder::addNotMatched(*this, context, wd, added_columns, idx, current_offset); \ + if unlikely (is_end) \ + { \ + ++idx; \ + break; \ + } \ + } \ + } + for (; idx < context.rows; ++idx) { if (has_null_map && (*context.null_map)[idx]) + { + NOT_MATCHED(true) continue; + } const auto & key = key_getter.getJoinKey(idx); auto hash = 
static_cast(Hash()(key)); @@ -410,17 +498,23 @@ JoinProbeBlockHelper= settings.max_block_size) + if constexpr (Adder::need_not_matched) + is_matched = true; + + bool is_end = Adder::addMatched( + *this, + context, + wd, + added_columns, + idx, + current_offset, + ptr, + key_offset + key_getter.getRequiredKeyOffset(key2)); + + if unlikely (is_end) break; } - ptr = HashJoinRowLayout::getNextRowPtr(ptr); + ptr = getNextRowPtr(ptr); if (ptr == nullptr) break; } if unlikely (ptr != nullptr) { - ptr = HashJoinRowLayout::getNextRowPtr(ptr); + ptr = getNextRowPtr(ptr); if (ptr == nullptr) ++idx; break; } + if constexpr (Adder::need_not_matched) + { + NOT_MATCHED(!is_matched) + } } - flushBatchIfNecessary(); - fillNullMapWithZero(current_offset - added_rows); + + Adder::flush(*this, wd, added_columns); context.start_row_idx = idx; context.current_row_ptr = ptr; + context.current_row_is_matched = is_matched; wd.collision += collision; + +#undef NOT_MATCHED } -template -void NO_INLINE -JoinProbeBlockHelper::joinProbeBlockInnerPrefetch() +#define PREFETCH_READ(ptr) __builtin_prefetch((ptr), 0 /* rw==read */, 3 /* locality */) + +JOIN_PROBE_TEMPLATE +void JoinProbeBlockHelper::probeFillColumnsPrefetch( + JoinProbeContext & context, + JoinProbeWorkerData & wd, + MutableColumns & added_columns) { + using KeyGetterType = typename KeyGetter::Type; + using Hash = typename KeyGetter::Hash; + using HashValueType = typename KeyGetter::HashValueType; + constexpr bool has_null_key = has_null_map || kind == LeftOuter; + if constexpr (!has_null_key) + { + for (auto [_, is_nullable] : row_layout.raw_key_column_indexes) + RUNTIME_CHECK_MSG(!is_nullable, "has_null_key is false but a key column is nullable"); + } + using Adder = ProbeAdder; + auto & key_getter = *static_cast(context.key_getter.get()); + initPrefetchStates(context); auto * states = static_cast *>(context.prefetch_states.get()); - auto & selective_offsets = wd.selective_offsets; size_t idx = context.start_row_idx; size_t 
active_states = context.prefetch_active_states; size_t k = context.prefetch_iter; - size_t current_offset = added_rows; + size_t current_offset = wd.result_block.rows(); size_t collision = 0; size_t key_offset = sizeof(RowPtr); if constexpr (KeyGetterType::joinKeyCompareHashFirst()) { key_offset += sizeof(HashValueType); } + +#define NOT_MATCHED(not_matched, idx) \ + if constexpr (Adder::need_not_matched) \ + { \ + if (not_matched) \ + { \ + bool is_end = Adder::addNotMatched(*this, context, wd, added_columns, idx, current_offset); \ + if unlikely (is_end) \ + break; \ + } \ + } + const size_t probe_prefetch_step = settings.probe_prefetch_step; while (idx < context.rows || active_states > 0) { @@ -482,7 +620,7 @@ JoinProbeBlockHelperstage == ProbePrefetchStage::FindNext) { RowPtr ptr = state->ptr; - RowPtr next_ptr = HashJoinRowLayout::getNextRowPtr(ptr); + RowPtr next_ptr = getNextRowPtr(ptr); if (next_ptr) { state->ptr = next_ptr; @@ -494,10 +632,19 @@ JoinProbeBlockHelperindex); - insertRowToBatch(key_getter, ptr + key_offset, key2); - if unlikely (current_offset >= settings.max_block_size) + if constexpr (Adder::need_not_matched) + state->is_matched = true; + + bool is_end = Adder::addMatched( + *this, + context, + wd, + added_columns, + state->index, + current_offset, + ptr, + key_offset + key_getter.getRequiredKeyOffset(key2)); + if unlikely (is_end) { if (!next_ptr) { @@ -516,6 +663,8 @@ JoinProbeBlockHelperstage = ProbePrefetchStage::None; --active_states; + + NOT_MATCHED(!state->is_matched, state->index); } else if (state->stage == ProbePrefetchStage::FindHeader) { @@ -542,18 +691,37 @@ JoinProbeBlockHelperstage = ProbePrefetchStage::None; --active_states; + + NOT_MATCHED(true, state->index); } assert(state->stage == ProbePrefetchStage::None); if constexpr (has_null_map) { + bool is_end = false; while (idx < context.rows) { if (!(*context.null_map)[idx]) break; + + if constexpr (Adder::need_not_matched) + { + is_end = Adder::addNotMatched(*this, context, 
wd, added_columns, idx, current_offset); + if unlikely (is_end) + { + ++idx; + break; + } + } + ++idx; } + if constexpr (Adder::need_not_matched) + { + if unlikely (is_end) + break; + } } if unlikely (idx >= context.rows) @@ -569,6 +737,8 @@ JoinProbeBlockHelperpointer_ptr); state->key = key; + if constexpr (Adder::need_not_matched) + state->is_matched = false; if constexpr (tagged_pointer) state->hash_tag = hash & ROW_PTR_TAG_MASK; if constexpr (KeyGetterType::joinKeyCompareHashFirst()) @@ -580,153 +750,371 @@ JoinProbeBlockHelper(); - fillNullMapWithZero(current_offset - added_rows); + Adder::flush(*this, wd, added_columns); context.start_row_idx = idx; context.prefetch_active_states = active_states; context.prefetch_iter = k; wd.collision += collision; + +#undef NOT_MATCHED } -template -void NO_INLINE -JoinProbeBlockHelper::joinProbeBlockLeftOuter() -{} - -template -void NO_INLINE -JoinProbeBlockHelper::joinProbeBlockLeftOuterPrefetch() -{} - -template -void NO_INLINE JoinProbeBlockHelper::joinProbeBlockSemi() -{} - -template -void NO_INLINE -JoinProbeBlockHelper::joinProbeBlockSemiPrefetch() -{} - -template -void NO_INLINE JoinProbeBlockHelper::joinProbeBlockAnti() -{} - -template -void NO_INLINE -JoinProbeBlockHelper::joinProbeBlockAntiPrefetch() -{} - -template -void NO_INLINE -JoinProbeBlockHelper::joinProbeBlockRightOuter() -{} - -template -void NO_INLINE -JoinProbeBlockHelper::joinProbeBlockRightOuterPrefetch() -{} - -template -void NO_INLINE -JoinProbeBlockHelper::joinProbeBlockRightSemi() -{} - -template -void NO_INLINE -JoinProbeBlockHelper::joinProbeBlockRightSemiPrefetch() -{} - -template -void NO_INLINE -JoinProbeBlockHelper::joinProbeBlockRightAnti() -{} - -template -void NO_INLINE -JoinProbeBlockHelper::joinProbeBlockRightAntiPrefetch() -{} - -void joinProbeBlock( +Block JoinProbeBlockHelper::handleOtherConditions( JoinProbeContext & context, JoinProbeWorkerData & wd, - HashJoinKeyMethod method, ASTTableJoin::Kind kind, - bool late_materialization, 
- const JoinNonEqualConditions & non_equal_conditions, - const HashJoinSettings & settings, - const HashJoinPointerTable & pointer_table, - const HashJoinRowLayout & row_layout, - MutableColumns & added_columns, - size_t added_rows) + bool late_materialization) { - if (context.rows == 0) - return; + const auto & left_sample_block_pruned = join->left_sample_block_pruned; + const auto & right_sample_block_pruned = join->right_sample_block_pruned; + const auto & output_block_after_finalize = join->output_block_after_finalize; + const auto & non_equal_conditions = join->non_equal_conditions; + const auto & output_column_indexes = join->output_column_indexes; + const auto & left_required_flag_for_other_condition = join->left_required_flag_for_other_condition; + + size_t left_columns = left_sample_block_pruned.columns(); + size_t right_columns = right_sample_block_pruned.columns(); + // Some columns in wd.result_block may be empty so need to create another block to execute other condition expressions + Block exec_block; + RUNTIME_CHECK(wd.result_block.columns() == left_columns + right_columns); + for (size_t i = 0; i < left_columns; ++i) + { + if (left_required_flag_for_other_condition[i]) + exec_block.insert(wd.result_block.getByPosition(i)); + } + if (late_materialization) + { + for (auto [column_index, _] : row_layout.raw_key_column_indexes) + exec_block.insert(wd.result_block.getByPosition(left_columns + column_index)); + for (size_t i = 0; i < row_layout.other_column_count_for_other_condition; ++i) + { + size_t column_index = row_layout.other_column_indexes[i].first; + exec_block.insert(wd.result_block.getByPosition(left_columns + column_index)); + } + } + else + { + for (size_t i = 0; i < right_columns; ++i) + exec_block.insert(wd.result_block.getByPosition(left_columns + i)); + } - switch (method) - { -#define CALL(KeyGetter, has_null_map, tagged_pointer, late_materialization) \ - JoinProbeBlockHelper( \ - context, \ - wd, \ - method, \ - kind, \ - 
non_equal_conditions, \ - settings, \ - pointer_table, \ - row_layout, \ - added_columns, \ - added_rows) \ - .joinProbeBlockImpl(); - -#define CALL3(KeyGetter, has_null_map, tagged_pointer) \ - if (late_materialization) \ - { \ - CALL(KeyGetter, has_null_map, tagged_pointer, true); \ - } \ - else \ - { \ - CALL(KeyGetter, has_null_map, tagged_pointer, false); \ - } - -#define CALL2(KeyGetter, has_null_map) \ - if (pointer_table.enableTaggedPointer()) \ - { \ - CALL3(KeyGetter, has_null_map, true); \ - } \ - else \ - { \ - CALL3(KeyGetter, has_null_map, false); \ - } - -#define CALL1(KeyGetter) \ - if (context.null_map) \ - { \ - CALL2(KeyGetter, true); \ - } \ - else \ - { \ - CALL2(KeyGetter, false); \ + non_equal_conditions.other_cond_expr->execute(exec_block); + + size_t rows = exec_block.rows(); + // Ensure BASE_OFFSETS is accessed within bound. + // It must be true because max_block_size <= BASE_OFFSETS.size(HASH_JOIN_MAX_BLOCK_SIZE_UPPER_BOUND). + RUNTIME_CHECK_MSG( + rows <= BASE_OFFSETS.size(), + "exec_block rows {} > base_offsets size {}", + rows, + BASE_OFFSETS.size()); + + wd.filter.clear(); + mergeNullAndFilterResult(exec_block, wd.filter, non_equal_conditions.other_cond_name, false); + exec_block.clear(); + + if (kind == LeftOuter) + { + RUNTIME_CHECK(wd.selective_offsets.size() == rows); + RUNTIME_CHECK(wd.filter.size() == rows); + for (size_t i = 0; i < rows; ++i) + { + size_t idx = wd.selective_offsets[i]; + bool is_matched = wd.filter[i]; + context.rows_not_matched[idx] &= !is_matched; + } } -#define M(METHOD) \ - case HashJoinKeyMethod::METHOD: \ - using KeyGetterType##METHOD = HashJoinKeyGetterForType; \ - CALL1(KeyGetterType##METHOD); \ - break; - APPLY_FOR_HASH_JOIN_VARIANTS(M) -#undef M + join->initOutputBlock(wd.result_block_for_other_condition); + + RUNTIME_CHECK_MSG( + wd.result_block_for_other_condition.rows() < settings.max_block_size, + "result_block_for_other_condition rows {} >= max_block_size {}", + 
wd.result_block_for_other_condition.rows(), + settings.max_block_size); + size_t remaining_insert_size = settings.max_block_size - wd.result_block_for_other_condition.rows(); + size_t result_size = countBytesInFilter(wd.filter); + + bool filter_offsets_is_initialized = false; + auto init_filter_offsets = [&]() { + RUNTIME_CHECK(wd.filter.size() == rows); + wd.filter_offsets.clear(); + wd.filter_offsets.reserve(result_size); + filterImpl(&wd.filter[0], &wd.filter[rows], &BASE_OFFSETS[0], wd.filter_offsets); + RUNTIME_CHECK(wd.filter_offsets.size() == result_size); + filter_offsets_is_initialized = true; + }; -#undef CALL1 -#undef CALL2 -#undef CALL3 -#undef CALL + bool filter_selective_offsets_is_initialized = false; + auto init_filter_selective_offsets = [&]() { + RUNTIME_CHECK(wd.selective_offsets.size() == rows); + wd.filter_selective_offsets.clear(); + wd.filter_selective_offsets.reserve(result_size); + filterImpl(&wd.filter[0], &wd.filter[rows], &wd.selective_offsets[0], wd.filter_selective_offsets); + RUNTIME_CHECK(wd.filter_selective_offsets.size() == result_size); + filter_selective_offsets_is_initialized = true; + }; - default: - throw Exception( - fmt::format("Unknown JOIN keys variant {}.", magic_enum::enum_name(method)), - ErrorCodes::UNKNOWN_SET_DATA_VARIANT); + bool filter_row_ptrs_for_lm_is_initialized = false; + auto init_filter_row_ptrs_for_lm = [&]() { + RUNTIME_CHECK(wd.row_ptrs_for_lm.size() == rows); + wd.filter_row_ptrs_for_lm.clear(); + wd.filter_row_ptrs_for_lm.reserve(result_size); + filterImpl(&wd.filter[0], &wd.filter[rows], &wd.row_ptrs_for_lm[0], wd.filter_row_ptrs_for_lm); + RUNTIME_CHECK(wd.filter_row_ptrs_for_lm.size() == result_size); + filter_row_ptrs_for_lm_is_initialized = true; + }; + + auto fill_matched = [&](size_t start, size_t length) { + if (length == 0) + return; + + if (late_materialization) + { + for (auto [column_index, _] : row_layout.raw_key_column_indexes) + { + auto output_index = 
output_column_indexes.at(left_columns + column_index); + if (output_index < 0) + continue; + if unlikely (!filter_offsets_is_initialized) + init_filter_offsets(); + auto & des_column = wd.result_block_for_other_condition.safeGetByPosition(output_index); + auto & src_column = wd.result_block.safeGetByPosition(left_columns + column_index); + des_column.column->assumeMutable() + ->insertSelectiveRangeFrom(*src_column.column.get(), wd.filter_offsets, start, length); + } + for (size_t i = 0; i < row_layout.other_column_count_for_other_condition; ++i) + { + size_t column_index = row_layout.other_column_indexes[i].first; + auto output_index = output_column_indexes.at(left_columns + column_index); + if (output_index < 0) + continue; + if unlikely (!filter_offsets_is_initialized) + init_filter_offsets(); + auto & des_column = wd.result_block_for_other_condition.safeGetByPosition(output_index); + auto & src_column = wd.result_block.safeGetByPosition(left_columns + column_index); + des_column.column->assumeMutable() + ->insertSelectiveRangeFrom(*src_column.column.get(), wd.filter_offsets, start, length); + } + + if (!filter_row_ptrs_for_lm_is_initialized) + init_filter_row_ptrs_for_lm(); + + if (row_layout.other_column_count_for_other_condition < row_layout.other_column_indexes.size()) + { + auto other_column_indexes_start = row_layout.other_column_count_for_other_condition; + auto other_column_indexes_size = row_layout.other_column_indexes.size(); + // Sanity check: all columns after other_column_indexes_start should be included in wd.result_block_for_other_condition. 
+ for (size_t i = other_column_indexes_start; i < other_column_indexes_size; ++i) + { + size_t column_index = row_layout.other_column_indexes[i].first; + auto output_index = output_column_indexes.at(left_columns + column_index); + RUNTIME_CHECK(output_index >= 0); + RUNTIME_CHECK(static_cast(output_index) < wd.result_block_for_other_condition.columns()); + } + constexpr size_t step = 256; + for (size_t pos = start; pos < start + length; pos += step) + { + size_t end = pos + step > start + length ? start + length : pos + step; + wd.insert_batch.clear(); + wd.insert_batch.insert(&wd.row_ptrs_for_lm[pos], &wd.row_ptrs_for_lm[end]); + for (size_t i = other_column_indexes_start; i < other_column_indexes_size; ++i) + { + size_t column_index = row_layout.other_column_indexes[i].first; + auto output_index = output_column_indexes[left_columns + column_index]; + auto & des_column = wd.result_block_for_other_condition.getByPosition(output_index); + des_column.column->assumeMutable()->deserializeAndInsertFromPos(wd.insert_batch, true); + } + } + for (size_t i = other_column_indexes_start; i < other_column_indexes_size; ++i) + { + size_t column_index = row_layout.other_column_indexes[i].first; + auto output_index = output_column_indexes[left_columns + column_index]; + auto & des_column = wd.result_block_for_other_condition.getByPosition(output_index); + des_column.column->assumeMutable()->flushNTAlignBuffer(); + } + } + } + else + { + for (size_t i = 0; i < right_columns; ++i) + { + auto output_index = output_column_indexes.at(left_columns + i); + if (output_index == -1) + continue; + if unlikely (!filter_offsets_is_initialized) + init_filter_offsets(); + auto & des_column = wd.result_block_for_other_condition.safeGetByPosition(output_index); + auto & src_column = wd.result_block.safeGetByPosition(left_columns + i); + des_column.column->assumeMutable() + ->insertSelectiveRangeFrom(*src_column.column.get(), wd.filter_offsets, start, length); + } + } + + for (size_t i = 0; i < 
left_columns; ++i) + { + auto output_index = output_column_indexes.at(i); + if (output_index == -1) + continue; + auto & des_column = wd.result_block_for_other_condition.safeGetByPosition(output_index); + if (left_required_flag_for_other_condition[i]) + { + if unlikely (!filter_offsets_is_initialized && !filter_selective_offsets_is_initialized) + init_filter_selective_offsets(); + if (filter_offsets_is_initialized) + { + auto & src_column = wd.result_block.safeGetByPosition(i); + des_column.column->assumeMutable() + ->insertSelectiveRangeFrom(*src_column.column.get(), wd.filter_offsets, start, length); + } + else + { + auto & src_column = context.block.safeGetByPosition(i); + des_column.column->assumeMutable()->insertSelectiveRangeFrom( + *src_column.column.get(), + wd.filter_selective_offsets, + start, + length); + } + continue; + } + if unlikely (!filter_selective_offsets_is_initialized) + init_filter_selective_offsets(); + auto & src_column = context.block.safeGetByPosition(i); + des_column.column->assumeMutable() + ->insertSelectiveRangeFrom(*src_column.column.get(), wd.filter_selective_offsets, start, length); + } + }; + + SCOPE_EXIT({ + RUNTIME_CHECK(wd.result_block.columns() == left_columns + right_columns); + /// Clear the data in result_block. 
+ for (size_t i = 0; i < left_columns + right_columns; ++i) + { + auto column = wd.result_block.getByPosition(i).column->assumeMutable(); + column->popBack(column->size()); + wd.result_block.getByPosition(i).column = std::move(column); + } + }); + + size_t length = std::min(result_size, remaining_insert_size); + fill_matched(0, length); + if (result_size >= remaining_insert_size) + { + Block res_block; + res_block.swap(wd.result_block_for_other_condition); + if (result_size > remaining_insert_size) + { + join->initOutputBlock(wd.result_block_for_other_condition); + fill_matched(remaining_insert_size, result_size - remaining_insert_size); + } + + return res_block; } + + if (kind == LeftOuter) + return fillNotMatchedRowsForLeftOuter(context, wd); + + return output_block_after_finalize; +} + +Block JoinProbeBlockHelper::fillNotMatchedRowsForLeftOuter(JoinProbeContext & context, JoinProbeWorkerData & wd) +{ + RUNTIME_CHECK(join->kind == LeftOuter); + RUNTIME_CHECK(join->has_other_condition); + RUNTIME_CHECK(context.isProbeFinished()); + if (context.not_matched_offsets_idx < 0) + { + size_t rows = context.rows; + size_t not_matched_result_size = countBytesInFilter(context.rows_not_matched); + auto & offsets = context.not_matched_offsets; + + offsets.clear(); + offsets.reserve(not_matched_result_size); + if likely (rows <= BASE_OFFSETS.size()) + { + filterImpl(&context.rows_not_matched[0], &context.rows_not_matched[rows], &BASE_OFFSETS[0], offsets); + RUNTIME_CHECK(offsets.size() == not_matched_result_size); + } + else + { + for (size_t i = 0; i < rows; ++i) + { + if (context.rows_not_matched[i]) + offsets.push_back(i); + } + } + + context.not_matched_offsets_idx = 0; + } + const auto & output_block_after_finalize = join->output_block_after_finalize; + + if (static_cast(context.not_matched_offsets_idx) >= context.not_matched_offsets.size()) + { + // JoinProbeContext::isAllFinished checks if all not matched rows have been output + // by verifying whether rows_not_matched 
is empty. + context.rows_not_matched.clear(); + return output_block_after_finalize; + } + + join->initOutputBlock(wd.result_block_for_other_condition); + + size_t left_columns = join->left_sample_block_pruned.columns(); + size_t right_columns = join->right_sample_block_pruned.columns(); + + if (wd.result_block_for_other_condition.rows() >= settings.max_block_size) + { + Block res_block; + res_block.swap(wd.result_block_for_other_condition); + return res_block; + } + + size_t remaining_insert_size = settings.max_block_size - wd.result_block_for_other_condition.rows(); + ; + size_t result_size = context.not_matched_offsets.size() - context.not_matched_offsets_idx; + size_t length = std::min(result_size, remaining_insert_size); + + const auto & output_column_indexes = join->output_column_indexes; + for (size_t i = 0; i < right_columns; ++i) + { + auto output_index = output_column_indexes.at(left_columns + i); + if (output_index == -1) + continue; + auto & des_column = wd.result_block_for_other_condition.safeGetByPosition(output_index); + des_column.column->assumeMutable()->insertManyDefaults(length); + } + + for (size_t i = 0; i < left_columns; ++i) + { + auto output_index = output_column_indexes.at(i); + if (output_index == -1) + continue; + auto & des_column = wd.result_block_for_other_condition.safeGetByPosition(output_index); + auto & src_column = context.block.safeGetByPosition(i); + des_column.column->assumeMutable()->insertSelectiveRangeFrom( + *src_column.column.get(), + context.not_matched_offsets, + context.not_matched_offsets_idx, + length); + } + context.not_matched_offsets_idx += length; + + if (static_cast(context.not_matched_offsets_idx) >= context.not_matched_offsets.size()) + { + // JoinProbeContext::isAllFinished checks if all not matched rows have been output + // by verifying whether rows_not_matched is empty. 
+ context.rows_not_matched.clear(); + } + + if (result_size >= remaining_insert_size) + { + Block res_block; + res_block.swap(wd.result_block_for_other_condition); + return res_block; + } + + return output_block_after_finalize; } } // namespace DB diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h index 22e8a7c0296..dbdcb28e795 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h @@ -41,7 +41,10 @@ struct JoinProbeContext /// For left outer/(left outer) (anti) semi join without other conditions. bool current_row_is_matched = false; /// For left outer/(left outer) (anti) semi join with other conditions. - IColumn::Filter rows_is_matched; + IColumn::Filter rows_not_matched; + /// < 0 means not_matched_offsets is not initialized. + ssize_t not_matched_offsets_idx = -1; + IColumn::Offsets not_matched_offsets; size_t prefetch_active_states = 0; size_t prefetch_iter = 0; @@ -56,7 +59,8 @@ struct JoinProbeContext bool input_is_finished = false; - bool isCurrentProbeFinished() const; + bool isProbeFinished() const; + bool isAllFinished() const; void resetBlock(Block & block_); void prepareForHashProbe( @@ -74,6 +78,8 @@ struct JoinProbeContext struct alignas(CPU_CACHE_LINE_SIZE) JoinProbeWorkerData { IColumn::Offsets selective_offsets; + // For left outer join with no other condition + IColumn::Offsets not_matched_selective_offsets; RowPtrs row_ptrs_for_lm; RowPtrs insert_batch; @@ -97,18 +103,210 @@ struct alignas(CPU_CACHE_LINE_SIZE) JoinProbeWorkerData Block result_block_for_other_condition; }; -void joinProbeBlock( - JoinProbeContext & context, - JoinProbeWorkerData & wd, - HashJoinKeyMethod method, - ASTTableJoin::Kind kind, - bool late_materialization, - const JoinNonEqualConditions & non_equal_conditions, - const HashJoinSettings & settings, - const HashJoinPointerTable & pointer_table, - const HashJoinRowLayout & row_layout, - MutableColumns & 
added_columns, - size_t added_rows); +/// The implementation of prefetching in join probe process is inspired by a paper named +`Asynchronous Memory Access Chaining` in vldb-15. +/// Ref: https://www.vldb.org/pvldb/vol9/p252-kocberber.pdf +enum class ProbePrefetchStage : UInt8 +{ + None, + FindHeader, + FindNext, +}; + +template +struct ProbePrefetchState +{ + using KeyGetterType = typename KeyGetter::Type; + using KeyType = typename KeyGetterType::KeyType; + using HashValueType = typename KeyGetter::HashValueType; + + ProbePrefetchStage stage = ProbePrefetchStage::None; + bool is_matched = false; + UInt16 hash_tag = 0; + HashValueType hash = 0; + size_t index = 0; + union + { + RowPtr ptr = nullptr; + std::atomic * pointer_ptr; + }; + KeyType key{}; +}; + +template +struct ProbeAdder; + +#define JOIN_PROBE_TEMPLATE \ + template < \ + typename KeyGetter, \ + ASTTableJoin::Kind kind, \ + bool has_null_map, \ + bool has_other_condition, \ + bool late_materialization, \ + bool tagged_pointer> + +class HashJoin; +class JoinProbeBlockHelper +{ +public: + JoinProbeBlockHelper(const HashJoin * join, bool late_materialization); + + Block probe(JoinProbeContext & context, JoinProbeWorkerData & wd); + +private: + JOIN_PROBE_TEMPLATE + Block probeImpl(JoinProbeContext & context, JoinProbeWorkerData & wd); + + JOIN_PROBE_TEMPLATE + void NO_INLINE + probeFillColumns(JoinProbeContext & context, JoinProbeWorkerData & wd, MutableColumns & added_columns); + JOIN_PROBE_TEMPLATE + void NO_INLINE + probeFillColumnsPrefetch(JoinProbeContext & context, JoinProbeWorkerData & wd, MutableColumns & added_columns); + template + void ALWAYS_INLINE initPrefetchStates(JoinProbeContext & context) + { + if (!context.prefetch_states) + { + context.prefetch_states = decltype(context.prefetch_states)( + static_cast(new ProbePrefetchState[settings.probe_prefetch_step]), + [](void * ptr) { delete[] static_cast *>(ptr); }); + } + } + + template + bool ALWAYS_INLINE joinKeyIsEqual( + KeyGetterType & 
key_getter, + const KeyType & key1, + const KeyType & key2, + HashValueType hash1, + RowPtr row_ptr) const + { + if constexpr (KeyGetterType::joinKeyCompareHashFirst()) + { + auto hash2 = unalignedLoad(row_ptr + sizeof(RowPtr)); + if (hash1 != hash2) + return false; + } + return key_getter.joinKeyIsEqual(key1, key2); + } + + template + void ALWAYS_INLINE insertRowToBatch(JoinProbeWorkerData & wd, MutableColumns & added_columns, RowPtr row_ptr) const + { + wd.insert_batch.push_back(row_ptr); + flushBatchIfNecessary(wd, added_columns); + } + + template + void ALWAYS_INLINE flushBatchIfNecessary(JoinProbeWorkerData & wd, MutableColumns & added_columns) const + { + if constexpr (!force) + { + if likely (wd.insert_batch.size() < settings.probe_insert_batch_size) + return; + } + if constexpr (late_materialization) + { + size_t idx = 0; + for (auto [_, is_nullable] : row_layout.raw_key_column_indexes) + { + IColumn * column = added_columns[idx].get(); + if (has_null_key && is_nullable) + column = &static_cast(*added_columns[idx]).getNestedColumn(); + column->deserializeAndInsertFromPos(wd.insert_batch, true); + ++idx; + } + for (size_t i = 0; i < row_layout.other_column_count_for_other_condition; ++i) + added_columns[idx++]->deserializeAndInsertFromPos(wd.insert_batch, true); + + wd.row_ptrs_for_lm.insert(wd.insert_batch.begin(), wd.insert_batch.end()); + } + else + { + for (auto [column_index, is_nullable] : row_layout.raw_key_column_indexes) + { + IColumn * column = added_columns[column_index].get(); + if (has_null_key && is_nullable) + column = &static_cast(*added_columns[column_index]).getNestedColumn(); + column->deserializeAndInsertFromPos(wd.insert_batch, true); + } + for (auto [column_index, _] : row_layout.other_column_indexes) + added_columns[column_index]->deserializeAndInsertFromPos(wd.insert_batch, true); + } + + if constexpr (force) + { + if constexpr (late_materialization) + { + size_t idx = 0; + for (auto [_, is_nullable] : 
row_layout.raw_key_column_indexes) + { + IColumn * column = added_columns[idx].get(); + if (has_null_key && is_nullable) + column = &static_cast(*added_columns[idx]).getNestedColumn(); + column->flushNTAlignBuffer(); + ++idx; + } + for (size_t i = 0; i < row_layout.other_column_count_for_other_condition; ++i) + added_columns[idx++]->flushNTAlignBuffer(); + } + else + { + for (auto [column_index, is_nullable] : row_layout.raw_key_column_indexes) + { + IColumn * column = added_columns[column_index].get(); + if (has_null_key && is_nullable) + column = &static_cast(*added_columns[column_index]).getNestedColumn(); + column->flushNTAlignBuffer(); + } + for (auto [column_index, _] : row_layout.other_column_indexes) + added_columns[column_index]->flushNTAlignBuffer(); + } + } + + wd.insert_batch.clear(); + } + + template + void ALWAYS_INLINE fillNullMapWithZero(MutableColumns & added_columns) const + { + if constexpr (has_null_key) + { + for (auto [column_index, is_nullable] : row_layout.raw_key_column_indexes) + { + if (is_nullable) + { + auto & nullable_column = static_cast(*added_columns[column_index]); + size_t data_size = nullable_column.getNestedColumn().size(); + size_t nullmap_size = nullable_column.getNullMapColumn().size(); + RUNTIME_CHECK(nullmap_size <= data_size); + nullable_column.getNullMapColumn().getData().resize_fill_zero(data_size); + } + } + } + } + + Block handleOtherConditions( + JoinProbeContext & context, + JoinProbeWorkerData & wd, + ASTTableJoin::Kind kind, + bool late_materialization); + + Block fillNotMatchedRowsForLeftOuter(JoinProbeContext & context, JoinProbeWorkerData & wd); + +private: + template + friend struct ProbeAdder; + + using FuncType = Block (JoinProbeBlockHelper::*)(JoinProbeContext &, JoinProbeWorkerData &); + FuncType func_ptr_has_null = nullptr; + FuncType func_ptr_no_null = nullptr; + const HashJoin * join; + const HashJoinSettings & settings; + const HashJoinPointerTable & pointer_table; + const HashJoinRowLayout & 
row_layout; +}; } // namespace DB diff --git a/dbms/src/Interpreters/JoinV2/HashJoinRowLayout.h b/dbms/src/Interpreters/JoinV2/HashJoinRowLayout.h index 72c88c1aaff..6cf6067d2bd 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinRowLayout.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinRowLayout.h @@ -14,18 +14,75 @@ #pragma once +#include #include +#include #include namespace DB { -constexpr size_t ROW_ALIGN = 4; +/// Row Layout +/// 1. if required hash value comparison: +/// +/// 2. if not required hash value comparison: +/// +struct HashJoinRowLayout +{ + /// The raw join key columns are the same as the original data. + /// raw_key_column_index in HashJoin::right_sample_block_pruned + is_nullable. + std::vector> raw_key_column_indexes; + /// other_column_index in HashJoin::right_sample_block_pruned + is_fixed_size + std::vector> other_column_indexes; + /// Number of columns at the beginning of `other_column_indexes` + /// that are used for evaluating the join other conditions. + size_t other_column_count_for_other_condition = 0; + + size_t key_column_fixed_size = 0; + size_t other_column_fixed_size = 0; +}; using RowPtr = char *; using RowPtrs = PaddedPODArray; +constexpr size_t ROW_ALIGN = 4; + +constexpr size_t ROW_PTR_TAG_BITS = 16; +constexpr size_t ROW_PTR_TAG_MASK = (1 << ROW_PTR_TAG_BITS) - 1; +constexpr size_t ROW_PTR_TAG_SHIFT = 8 * sizeof(RowPtr) - ROW_PTR_TAG_BITS; +static_assert(sizeof(RowPtr) == sizeof(uintptr_t)); +static_assert(sizeof(RowPtr) == 8); + +inline RowPtr getNextRowPtr(const RowPtr ptr) +{ + return unalignedLoad(ptr); +} + +inline UInt16 getRowPtrTag(RowPtr ptr) +{ + auto address = reinterpret_cast(ptr); + return address >> ROW_PTR_TAG_SHIFT; +} + +inline bool isRowPtrTagZero(RowPtr ptr) +{ + return getRowPtrTag(ptr) == 0; +} + +inline RowPtr removeRowPtrTag(RowPtr ptr) +{ + auto address = reinterpret_cast(ptr); + address &= (1ULL << ROW_PTR_TAG_SHIFT) - 1; + return reinterpret_cast(address); +} + +inline bool containOtherTag(RowPtr 
ptr, UInt16 other_tag) +{ + UInt16 tag = getRowPtrTag(ptr); + return (tag | other_tag) == tag; +} + struct RowContainer { PaddedPODArray data; @@ -71,63 +128,4 @@ struct alignas(CPU_CACHE_LINE_SIZE) MultipleRowContainer } }; -/// Row Layout -/// 1. Not-null join key row if required hash value comparison: -/// -/// 2. Not-null join key row if not required hash value comparison: -/// -/// 3. Null join key row(For right anti/outer join): -struct HashJoinRowLayout -{ - /// The raw join key are the same as the original data. - /// raw_required_key_column_index + is_nullable - std::vector> raw_required_key_column_indexes; - /// other_required_column_index + is_fixed_size - std::vector> other_required_column_indexes; - /// Number of columns at the beginning of `other_required_column_indexes` - /// that are used for evaluating the join other condition. - size_t other_required_count_for_other_condition = 0; - - size_t key_column_fixed_size = 0; - size_t other_column_fixed_size = 0; - - static RowPtr getNextRowPtr(const RowPtr ptr) { return unalignedLoad(ptr); } -}; - -constexpr size_t ROW_PTR_TAG_BITS = 16; -constexpr size_t ROW_PTR_TAG_MASK = (1 << ROW_PTR_TAG_BITS) - 1; - -inline UInt16 getRowPtrTag(RowPtr ptr) -{ - static_assert(sizeof(RowPtr) == 8); - auto address = reinterpret_cast(ptr); - return address >> (64 - ROW_PTR_TAG_BITS); -} - -inline bool isRowPtrTagZero(RowPtr ptr) -{ - return getRowPtrTag(ptr) == 0; -} - -inline RowPtr removeRowPtrTag(RowPtr ptr) -{ - auto address = reinterpret_cast(ptr); - address &= (1ULL << (64 - ROW_PTR_TAG_BITS)) - 1; - return reinterpret_cast(address); -} - -inline RowPtr addRowPtrTag(RowPtr ptr, UInt16 tag) -{ - static_assert(sizeof(uintptr_t) == 8); - auto address = reinterpret_cast(ptr); - address |= static_cast(tag) << (64 - ROW_PTR_TAG_BITS); - return reinterpret_cast(address); -} - -inline bool containOtherTag(RowPtr ptr, UInt16 other_tag) -{ - UInt16 tag = getRowPtrTag(ptr); - return (tag | other_tag) == tag; -} - } // 
namespace DB diff --git a/dbms/src/Interpreters/JoinV2/HashJoinSettings.h b/dbms/src/Interpreters/JoinV2/HashJoinSettings.h index f0c60b8d9e1..0154a66016f 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinSettings.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinSettings.h @@ -14,15 +14,27 @@ #pragma once +#include #include namespace DB { +/// The max_block_size upper bound of hash join. +#define HASH_JOIN_MAX_BLOCK_SIZE_UPPER_BOUND (65536 * 2) + +const IColumn::Offsets BASE_OFFSETS = [] { + IColumn::Offsets offsets(HASH_JOIN_MAX_BLOCK_SIZE_UPPER_BOUND); + std::iota(offsets.begin(), offsets.end(), 0ULL); + return offsets; +}(); + struct HashJoinSettings { explicit HashJoinSettings(const Settings & settings) - : max_block_size(std::min(settings.max_block_size, settings.join_v2_max_block_size)) + : max_block_size(std::min( + HASH_JOIN_MAX_BLOCK_SIZE_UPPER_BOUND, + std::min(settings.max_block_size, settings.join_v2_max_block_size))) , probe_enable_prefetch_threshold(settings.join_v2_probe_enable_prefetch_threshold) , probe_prefetch_step(settings.join_v2_probe_prefetch_step) , probe_insert_batch_size(settings.join_v2_probe_insert_batch_size) diff --git a/dbms/src/Interpreters/LogicalExpressionsOptimizer.cpp b/dbms/src/Interpreters/LogicalExpressionsOptimizer.cpp index 460d73b3a33..e50473ac4e8 100644 --- a/dbms/src/Interpreters/LogicalExpressionsOptimizer.cpp +++ b/dbms/src/Interpreters/LogicalExpressionsOptimizer.cpp @@ -118,30 +118,30 @@ void LogicalExpressionsOptimizer::collectDisjunctiveEqualityChains() while (!to_visit.empty()) { auto edge = to_visit.back(); - auto from_node = edge.first; - auto to_node = edge.second; + auto * from_node = edge.first; + auto * to_node = edge.second; to_visit.pop_back(); bool found_chain = false; - auto function = typeid_cast(to_node); + auto * function = typeid_cast(to_node); if ((function != nullptr) && (function->name == "or") && (function->children.size() == 1)) { - auto expression_list = typeid_cast(&*(function->children[0])); + 
auto * expression_list = typeid_cast(&*(function->children[0])); if (expression_list != nullptr) { /// The chain of elements of the OR expression. for (auto & child : expression_list->children) { - auto equals = typeid_cast(&*child); + auto * equals = typeid_cast(&*child); if ((equals != nullptr) && (equals->name == "equals") && (equals->children.size() == 1)) { - auto equals_expression_list = typeid_cast(&*(equals->children[0])); + auto * equals_expression_list = typeid_cast(&*(equals->children[0])); if ((equals_expression_list != nullptr) && (equals_expression_list->children.size() == 2)) { /// Equality expr = xN. - auto literal = typeid_cast(&*(equals_expression_list->children[1])); + auto * literal = typeid_cast(&*(equals_expression_list->children[1])); if (literal != nullptr) { auto expr_lhs = equals_expression_list->children[0]->getTreeHash(); @@ -203,7 +203,7 @@ namespace { inline ASTs & getFunctionOperands(ASTFunction * or_function) { - auto expression_list = static_cast(&*(or_function->children[0])); + auto * expression_list = static_cast(&*(or_function->children[0])); return expression_list->children; } @@ -220,11 +220,11 @@ bool LogicalExpressionsOptimizer::mayOptimizeDisjunctiveEqualityChain(const Disj /// We check that the right-hand sides of all equalities have the same type. auto & first_operands = getFunctionOperands(equality_functions[0]); - auto first_literal = static_cast(&*first_operands[1]); + auto * first_literal = static_cast(&*first_operands[1]); for (size_t i = 1; i < equality_functions.size(); ++i) { auto & operands = getFunctionOperands(equality_functions[i]); - auto literal = static_cast(&*operands[1]); + auto * literal = static_cast(&*operands[1]); if (literal->value.getType() != first_literal->value.getType()) return false; @@ -242,7 +242,7 @@ void LogicalExpressionsOptimizer::addInExpression(const DisjunctiveEqualityChain /// Construct a list of literals `x1, ..., xN` from the string `expr = x1 OR ... 
OR expr = xN` ASTPtr value_list = std::make_shared(); - for (const auto function : equality_functions) + for (auto * const function : equality_functions) { const auto & operands = getFunctionOperands(function); value_list->children.push_back(operands[1]); @@ -254,15 +254,15 @@ void LogicalExpressionsOptimizer::addInExpression(const DisjunctiveEqualityChain value_list->children.begin(), value_list->children.end(), [](const DB::ASTPtr & lhs, const DB::ASTPtr & rhs) { - const auto val_lhs = static_cast(&*lhs); - const auto val_rhs = static_cast(&*rhs); + const auto * const val_lhs = static_cast(&*lhs); + const auto * const val_rhs = static_cast(&*rhs); return val_lhs->value < val_rhs->value; }); /// Get the expression `expr` from the chain `expr = x1 OR ... OR expr = xN` ASTPtr equals_expr_lhs; { - auto function = equality_functions[0]; + auto * function = equality_functions[0]; const auto & operands = getFunctionOperands(function); equals_expr_lhs = operands[0]; } @@ -331,7 +331,7 @@ void LogicalExpressionsOptimizer::cleanupOrExpressions() /// Delete garbage. for (const auto & entry : garbage_map) { - auto function = entry.first; + auto * function = entry.first; auto first_erased = entry.second; auto & operands = getFunctionOperands(function); @@ -348,7 +348,7 @@ void LogicalExpressionsOptimizer::fixBrokenOrExpressions() continue; const auto & or_with_expression = chain.first; - auto or_function = or_with_expression.or_function; + auto * or_function = or_with_expression.or_function; auto & operands = getFunctionOperands(or_with_expression.or_function); if (operands.size() == 1) @@ -381,12 +381,10 @@ void LogicalExpressionsOptimizer::fixBrokenOrExpressions() parent->children.erase(first_erased, parent->children.end()); } - /// If the OR node was the root of the WHERE, PREWHERE, or HAVING expression, then update this root. + /// If the OR node was the root of the WHERE, or HAVING expression, then update this root. 
/// Due to the fact that we are dealing with a directed acyclic graph, we must check all cases. if (select_query->where_expression && (or_function == &*(select_query->where_expression))) select_query->where_expression = operands[0]; - if (select_query->prewhere_expression && (or_function == &*(select_query->prewhere_expression))) - select_query->prewhere_expression = operands[0]; if (select_query->having_expression && (or_function == &*(select_query->having_expression))) select_query->having_expression = operands[0]; } diff --git a/dbms/src/Interpreters/LogicalExpressionsOptimizer.h b/dbms/src/Interpreters/LogicalExpressionsOptimizer.h index 697ce9464e6..084e5a1e997 100644 --- a/dbms/src/Interpreters/LogicalExpressionsOptimizer.h +++ b/dbms/src/Interpreters/LogicalExpressionsOptimizer.h @@ -84,7 +84,7 @@ class LogicalExpressionsOptimizer final bool mayOptimizeDisjunctiveEqualityChain(const DisjunctiveEqualityChain & chain) const; /// Insert the IN expression into the OR chain. - void addInExpression(const DisjunctiveEqualityChain & chain); + static void addInExpression(const DisjunctiveEqualityChain & chain); /// Delete the equalities that were replaced by the IN expressions. 
void cleanupOrExpressions(); diff --git a/dbms/src/Interpreters/NullAwareSemiJoinHelper.cpp b/dbms/src/Interpreters/NullAwareSemiJoinHelper.cpp index 7bec95bd4e9..daf5172fa71 100644 --- a/dbms/src/Interpreters/NullAwareSemiJoinHelper.cpp +++ b/dbms/src/Interpreters/NullAwareSemiJoinHelper.cpp @@ -17,10 +17,7 @@ #include #include #include - -#include "Functions/FunctionBinaryArithmetic.h" -#include "Interpreters/SemiJoinHelper.h" -#include "Parsers/ASTTablesInSelectQuery.h" +#include namespace DB { diff --git a/dbms/src/Interpreters/SemiJoinHelper.cpp b/dbms/src/Interpreters/SemiJoinHelper.cpp index 7a489bea541..918996a7efe 100644 --- a/dbms/src/Interpreters/SemiJoinHelper.cpp +++ b/dbms/src/Interpreters/SemiJoinHelper.cpp @@ -415,7 +415,7 @@ Block SemiJoinHelper::genJoinResult(const NameSet & outp left_semi_null_map = &left_semi_column->getNullMapColumn().getData(); if constexpr (STRICTNESS == ASTTableJoin::Strictness::Any) { - left_semi_null_map->resize_fill(probe_rows, 0); + left_semi_null_map->resize_fill_zero(probe_rows); } else { @@ -449,21 +449,10 @@ Block SemiJoinHelper::genJoinResult(const NameSet & outp } else { - switch (result) - { - case SemiJoinResultType::FALSE_VALUE: - left_semi_column_data->push_back(0); - left_semi_null_map->push_back(0); - break; - case SemiJoinResultType::TRUE_VALUE: - left_semi_column_data->push_back(1); - left_semi_null_map->push_back(0); - break; - case SemiJoinResultType::NULL_VALUE: - left_semi_column_data->push_back(0); - left_semi_null_map->push_back(1); - break; - } + Int8 res = result == SemiJoinResultType::TRUE_VALUE ? 1 : 0; + UInt8 is_null = result == SemiJoinResultType::NULL_VALUE ? 
1 : 0; + left_semi_column_data->push_back(res); + left_semi_null_map->push_back(is_null); } } } diff --git a/dbms/src/Interpreters/tests/gtest_interpreter_create_query.cpp b/dbms/src/Interpreters/tests/gtest_interpreter_create_query.cpp index 132893422ef..dd767fd34b6 100644 --- a/dbms/src/Interpreters/tests/gtest_interpreter_create_query.cpp +++ b/dbms/src/Interpreters/tests/gtest_interpreter_create_query.cpp @@ -93,11 +93,8 @@ class InterperCreateQueryTiFlashTest : public ::testing::Test } } - static DB::ASTPtr getASTCreateQuery() + static DB::ASTPtr getASTCreateQuery(const String & stmt) { - String stmt - = R"json(CREATE TABLE `db_2`.`t_88`(`a` Nullable(Int32), `b` Nullable(Int32), `_tidb_rowid` Int64) Engine = DeltaMerge((`_tidb_rowid`), '{"cols":[{"comment":"","default":null,"default_bit":null,"id":1,"name":{"L":"a","O":"a"},"offset":0,"origin_default":null,"state":5,"type":{"Charset":"binary","Collate":"binary","Decimal":0,"Elems":null,"Flag":0,"Flen":11,"Tp":3}},{"comment":"","default":null,"default_bit":null,"id":2,"name":{"L":"b","O":"b"},"offset":1,"origin_default":null,"state":5,"type":{"Charset":"binary","Collate":"binary","Decimal":0,"Elems":null,"Flag":0,"Flen":11,"Tp":3}}],"comment":"","id":88,"index_info":[],"is_common_handle":false,"keyspace_id":4294967295,"name":{"L":"t1","O":"t1"},"partition":null,"pk_is_handle":false,"schema_version":-1,"state":5,"tiflash_replica":{"Available":false,"Count":1},"update_timestamp":442125004587401229}'))json"; - String table_info_json = 
R"json({"id":88,"name":{"O":"t1","L":"t1"},"charset":"utf8mb4","collate":"utf8mb4_bin","cols":[{"id":1,"name":{"O":"a","L":"a"},"offset":0,"origin_default":null,"origin_default_bit":null,"default":null,"default_bit":null,"default_is_expr":false,"generated_expr_string":"","generated_stored":false,"dependences":null,"type":{"Tp":3,"Flag":0,"Flen":11,"Decimal":0,"Charset":"binary","Collate":"binary","Elems":null,"ElemsIsBinaryLit":null,"Array":false},"state":5,"comment":"","hidden":false,"change_state_info":null,"version":2},{"id":2,"name":{"O":"b","L":"b"},"offset":1,"origin_default":null,"origin_default_bit":null,"default":null, "default_bit":null,"default_is_expr":false,"generated_expr_string":"","generated_stored":false,"dependences":null,"type":{"Tp":3,"Flag":0,"Flen":11,"Decimal":0,"Charset":"binary","Collate":"binary","Elems":null,"ElemsIsBinaryLit":null,"Array":false},"state":5,"comment":"","hidden":false,"change_state_info":null,"version":2}],"index_info":null,"constraint_info":null,"fk_info":null,"state":5,"pk_is_handle":false,"is_common_handle":false,"common_handle_version":0,"comment":"","auto_inc_id":0,"auto_id_cache":0,"auto_rand_id":0,"max_col_id":2,"max_idx_id":0,"max_fk_id":0,"max_cst_id":0,"update_timestamp":442125004587401229,"ShardRowIDBits":0,"max_shard_row_id_bits":0,"auto_random_bits":0,"auto_random_range_bits":0,"pre_split_regions":0, "partition":null,"compression":"","view":null,"sequence":null,"Lock":null,"version":5,"tiflash_replica":{"Count":1,"LocationLabels":[],"Available":false,"AvailablePartitionIDs":null},"is_columnar":false,"temp_table_type":0,"cache_table_status":0,"policy_ref_info":null,"stats_options":null,"exchange_partition_info":null,"ttl_info":null})json"; @@ -146,7 +143,10 @@ try for (auto & thread : threads) { thread = std::thread([&] { - auto ast = getASTCreateQuery(); + // The `stmt` should live longer than `ast` + String stmt + = R"json(CREATE TABLE `db_2`.`t_88`(`a` Nullable(Int32), `b` Nullable(Int32), `_tidb_rowid` 
Int64) Engine = DeltaMerge((`_tidb_rowid`), '{"cols":[{"comment":"","default":null,"default_bit":null,"id":1,"name":{"L":"a","O":"a"},"offset":0,"origin_default":null,"state":5,"type":{"Charset":"binary","Collate":"binary","Decimal":0,"Elems":null,"Flag":0,"Flen":11,"Tp":3}},{"comment":"","default":null,"default_bit":null,"id":2,"name":{"L":"b","O":"b"},"offset":1,"origin_default":null,"state":5,"type":{"Charset":"binary","Collate":"binary","Decimal":0,"Elems":null,"Flag":0,"Flen":11,"Tp":3}}],"comment":"","id":88,"index_info":[],"is_common_handle":false,"keyspace_id":4294967295,"name":{"L":"t1","O":"t1"},"partition":null,"pk_is_handle":false,"schema_version":-1,"state":5,"tiflash_replica":{"Available":false,"Count":1},"update_timestamp":442125004587401229}'))json"; + auto ast = getASTCreateQuery(stmt); InterpreterCreateQuery interpreter(ast, context); interpreter.setInternal(true); interpreter.setForceRestoreData(false); diff --git a/dbms/src/Operators/HashJoinV2ProbeTransformOp.cpp b/dbms/src/Operators/HashJoinV2ProbeTransformOp.cpp index f973aa439df..efd8ce06bbc 100644 --- a/dbms/src/Operators/HashJoinV2ProbeTransformOp.cpp +++ b/dbms/src/Operators/HashJoinV2ProbeTransformOp.cpp @@ -50,7 +50,7 @@ void HashJoinV2ProbeTransformOp::operateSuffixImpl() OperatorStatus HashJoinV2ProbeTransformOp::onOutput(Block & block) { - assert(!probe_context.isCurrentProbeFinished()); + assert(!probe_context.isAllFinished()); block = join_ptr->probeBlock(probe_context, op_index); size_t rows = block.rows(); joined_rows += rows; @@ -59,7 +59,7 @@ OperatorStatus HashJoinV2ProbeTransformOp::onOutput(Block & block) OperatorStatus HashJoinV2ProbeTransformOp::transformImpl(Block & block) { - assert(probe_context.isCurrentProbeFinished()); + assert(probe_context.isAllFinished()); if unlikely (!block) { join_ptr->finishOneProbe(op_index); @@ -80,7 +80,7 @@ OperatorStatus HashJoinV2ProbeTransformOp::tryOutputImpl(Block & block) block = {}; return OperatorStatus::HAS_OUTPUT; } - if 
(probe_context.isCurrentProbeFinished()) + if (probe_context.isAllFinished()) return OperatorStatus::NEED_INPUT; return onOutput(block); } diff --git a/dbms/src/Operators/MergeSortTransformOp.cpp b/dbms/src/Operators/MergeSortTransformOp.cpp index 6e25bee9658..47414eb6fa3 100644 --- a/dbms/src/Operators/MergeSortTransformOp.cpp +++ b/dbms/src/Operators/MergeSortTransformOp.cpp @@ -149,6 +149,11 @@ OperatorStatus MergeSortTransformOp::transformImpl(Block & block) // store the sorted block in `sorted_blocks`. SortHelper::removeConstantsFromBlock(block); + RUNTIME_CHECK_MSG( + block.columns() == header_without_constants.columns(), + "Unexpected number of constant columns in block in MergeSortTransformOp, n_block={}, n_head={}", + block.columns(), + header_without_constants.columns()); sum_bytes_in_blocks += block.estimateBytesForSpill(); sorted_blocks.emplace_back(std::move(block)); diff --git a/dbms/src/Parsers/ASTInsertQuery.h b/dbms/src/Parsers/ASTInsertQuery.h index 6210504bcef..75db5e82fe5 100644 --- a/dbms/src/Parsers/ASTInsertQuery.h +++ b/dbms/src/Parsers/ASTInsertQuery.h @@ -44,9 +44,6 @@ class ASTInsertQuery : public IAST ASTPtr table_function; ASTPtr partition_expression_list; - // Set to true if the data should only be inserted into attached views - bool no_destination = false; - /// Data to insert const char * data = nullptr; const char * end = nullptr; @@ -57,7 +54,7 @@ class ASTInsertQuery : public IAST bool is_delete = false; /** Get the text that identifies this element. 
*/ - String getID() const override { return "InsertQuery_" + database + "_" + table; }; + String getID() const override { return "InsertQuery_" + database + "_" + table; } ASTPtr clone() const override { diff --git a/dbms/src/Parsers/ASTSelectQuery.cpp b/dbms/src/Parsers/ASTSelectQuery.cpp index d722e40a25a..5366c57489b 100644 --- a/dbms/src/Parsers/ASTSelectQuery.cpp +++ b/dbms/src/Parsers/ASTSelectQuery.cpp @@ -58,7 +58,6 @@ ASTPtr ASTSelectQuery::clone() const CLONE(with_expression_list) CLONE(select_expression_list) CLONE(tables) - CLONE(prewhere_expression) CLONE(where_expression) CLONE(group_expression_list) CLONE(having_expression) @@ -118,13 +117,6 @@ void ASTSelectQuery::formatImpl(const FormatSettings & s, FormatState & state, F segment_expression_list->formatImpl(s, state, frame); } - if (prewhere_expression) - { - s.ostr << (s.hilite ? hilite_keyword : "") << s.nl_or_ws << indent_str << "PREWHERE " - << (s.hilite ? hilite_none : ""); - prewhere_expression->formatImpl(s, state, frame); - } - if (where_expression) { s.ostr << (s.hilite ? hilite_keyword : "") << s.nl_or_ws << indent_str << "WHERE " diff --git a/dbms/src/Parsers/ASTSelectQuery.h b/dbms/src/Parsers/ASTSelectQuery.h index 66d9d795e28..96e67094b31 100644 --- a/dbms/src/Parsers/ASTSelectQuery.h +++ b/dbms/src/Parsers/ASTSelectQuery.h @@ -30,7 +30,7 @@ class ASTSelectQuery : public IAST { public: /** Get the text that identifies this element. 
*/ - String getID() const override { return "SelectQuery"; }; + String getID() const override { return "SelectQuery"; } ASTPtr clone() const override; @@ -42,7 +42,6 @@ class ASTSelectQuery : public IAST ASTPtr tables; ASTPtr partition_expression_list; ASTPtr segment_expression_list; - ASTPtr prewhere_expression; ASTPtr where_expression; ASTPtr group_expression_list; ASTPtr having_expression; diff --git a/dbms/src/Parsers/ExpressionElementParsers.cpp b/dbms/src/Parsers/ExpressionElementParsers.cpp index e7b13abdc0a..155b5d2ece5 100644 --- a/dbms/src/Parsers/ExpressionElementParsers.cpp +++ b/dbms/src/Parsers/ExpressionElementParsers.cpp @@ -647,9 +647,9 @@ bool ParserLiteral::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) const char * ParserAliasBase::restricted_keywords[] - = {"FROM", "FINAL", "SAMPLE", "ARRAY", "LEFT", "RIGHT", "INNER", "FULL", "CROSS", "JOIN", - "GLOBAL", "ANY", "ALL", "ON", "USING", "PREWHERE", "WHERE", "GROUP", "WITH", "HAVING", - "ORDER", "LIMIT", "SETTINGS", "FORMAT", "UNION", "INTO", "PARTITION", "SEGMENT", nullptr}; + = {"FROM", "FINAL", "SAMPLE", "ARRAY", "LEFT", "RIGHT", "INNER", "FULL", "CROSS", "JOIN", + "GLOBAL", "ANY", "ALL", "ON", "USING", "WHERE", "GROUP", "WITH", "HAVING", // + "ORDER", "LIMIT", "SETTINGS", "FORMAT", "UNION", "INTO", "PARTITION", "SEGMENT", nullptr}; template bool ParserAliasImpl::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) diff --git a/dbms/src/Parsers/ParserSelectQuery.cpp b/dbms/src/Parsers/ParserSelectQuery.cpp index ae86a9be406..4f828157ff4 100644 --- a/dbms/src/Parsers/ParserSelectQuery.cpp +++ b/dbms/src/Parsers/ParserSelectQuery.cpp @@ -47,7 +47,6 @@ bool ParserSelectQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) ParserKeyword s_from("FROM"); ParserKeyword s_partition("PARTITION"); ParserKeyword s_segment("SEGMENT"); - ParserKeyword s_prewhere("PREWHERE"); ParserKeyword s_where("WHERE"); ParserKeyword s_group_by("GROUP BY"); ParserKeyword s_with("WITH"); @@ -114,13 
+113,6 @@ bool ParserSelectQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) return false; } - /// PREWHERE expr - if (s_prewhere.ignore(pos, expected)) - { - if (!exp_elem.parse(pos, select_query->prewhere_expression, expected)) - return false; - } - /// WHERE expr if (s_where.ignore(pos, expected)) { @@ -217,8 +209,6 @@ bool ParserSelectQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) select_query->children.push_back(select_query->select_expression_list); if (select_query->tables) select_query->children.push_back(select_query->tables); - if (select_query->prewhere_expression) - select_query->children.push_back(select_query->prewhere_expression); if (select_query->where_expression) select_query->children.push_back(select_query->where_expression); if (select_query->group_expression_list) diff --git a/dbms/src/Server/Client.cpp b/dbms/src/Server/Client.cpp index 5d83a478515..65c21a4250b 100644 --- a/dbms/src/Server/Client.cpp +++ b/dbms/src/Server/Client.cpp @@ -16,7 +16,6 @@ #include #include #include -#include #include #include #include @@ -51,6 +50,7 @@ #include #include #include +#include #include #include #include @@ -59,7 +59,6 @@ #include #include -#include #include #include #include @@ -179,10 +178,6 @@ class Client : public Poco::Util::Application size_t written_progress_chars = 0; bool written_first_block = false; - /// External tables info. - std::list external_tables; - - struct ConnectionParameters { String host; @@ -734,28 +729,10 @@ class Client : public Poco::Util::Application return true; } - - /// Convert external tables to ExternalTableData and send them using the connection. 
- void sendExternalTables() - { - const auto * select = typeid_cast(&*parsed_query); - if (!select && !external_tables.empty()) - throw Exception("External tables could be sent only with select query", ErrorCodes::BAD_ARGUMENTS); - - std::vector data; - for (auto & table : external_tables) - data.emplace_back(table.getData(*context)); - - connection->sendExternalTablesData(data); - } - - /// Process the query that doesn't require transfering data blocks to the server. void processOrdinaryQuery() { - connection - ->sendQuery(query, query_id, QueryProcessingStage::Complete, &context->getSettingsRef(), nullptr, true); - sendExternalTables(); + connection->sendQuery(query, query_id, QueryProcessingStage::Complete, &context->getSettingsRef(), nullptr); receiveResult(); } @@ -776,9 +753,7 @@ class Client : public Poco::Util::Application query_id, QueryProcessingStage::Complete, &context->getSettingsRef(), - nullptr, - true); - sendExternalTables(); + nullptr); /// Receive description of table structure. Block sample; @@ -1228,56 +1203,16 @@ class Client : public Poco::Util::Application /** We allow different groups of arguments: * - common arguments; - * - arguments for any number of external tables each in form "--external args...", - * where possible args are file, name, format, structure, types. * Split these groups before processing. */ using Arguments = std::vector; Arguments common_arguments{""}; /// 0th argument is ignored. - std::vector external_tables_arguments; - bool in_external_group = false; for (int arg_num = 1; arg_num < argc; ++arg_num) { const char * arg = argv[arg_num]; - - if (0 == strcmp(arg, "--external")) - { - in_external_group = true; - external_tables_arguments.emplace_back(Arguments{""}); - } - /// Options with value after equal sign. 
- else if ( - in_external_group - && (0 == strncmp(arg, "--file=", strlen("--file=")) || 0 == strncmp(arg, "--name=", strlen("--name=")) - || 0 == strncmp(arg, "--format=", strlen("--format=")) - || 0 == strncmp(arg, "--structure=", strlen("--structure=")) - || 0 == strncmp(arg, "--types=", strlen("--types=")))) - { - external_tables_arguments.back().emplace_back(arg); - } - /// Options with value after whitespace. - else if ( - in_external_group - && (0 == strcmp(arg, "--file") || 0 == strcmp(arg, "--name") || 0 == strcmp(arg, "--format") - || 0 == strcmp(arg, "--structure") || 0 == strcmp(arg, "--types"))) - { - if (arg_num + 1 < argc) - { - external_tables_arguments.back().emplace_back(arg); - ++arg_num; - arg = argv[arg_num]; - external_tables_arguments.back().emplace_back(arg); - } - else - break; - } - else - { - in_external_group = false; - common_arguments.emplace_back(arg); - } + common_arguments.emplace_back(arg); } #define DECLARE_SETTING(TYPE, NAME, DEFAULT, DESCRIPTION) \ @@ -1313,18 +1248,6 @@ class Client : public Poco::Util::Application // clang-format on #undef DECLARE_SETTING - /// Commandline options related to external tables. - boost::program_options::options_description external_description("External tables options"); - // clang-format off - external_description.add_options() - ("file", boost::program_options::value(), "data file or - for stdin") - ("name", boost::program_options::value()->default_value("_data"), "name of the table") - ("format", boost::program_options::value()->default_value("TabSeparated"), "data format") - ("structure", boost::program_options::value(), "structure") - ("types", boost::program_options::value(), "types") - ; - // clang-format on - /// Parse main commandline options. 
boost::program_options::parsed_options parsed = boost::program_options::command_line_parser(common_arguments.size(), common_arguments.data()) @@ -1345,41 +1268,9 @@ class Client : public Poco::Util::Application && options["host"].as() == "elp")) /// If user writes -help instead of --help. { std::cout << main_description << "\n"; - std::cout << external_description << "\n"; exit(0); } - size_t number_of_external_tables_with_stdin_source = 0; - for (size_t i = 0; i < external_tables_arguments.size(); ++i) - { - /// Parse commandline options related to external tables. - boost::program_options::parsed_options parsed = boost::program_options::command_line_parser( - external_tables_arguments[i].size(), - external_tables_arguments[i].data()) - .options(external_description) - .run(); - boost::program_options::variables_map external_options; - boost::program_options::store(parsed, external_options); - - try - { - external_tables.emplace_back(external_options); - if (external_tables.back().file == "-") - ++number_of_external_tables_with_stdin_source; - if (number_of_external_tables_with_stdin_source > 1) - throw Exception( - "Two or more external tables has stdin (-) set as --file field", - ErrorCodes::BAD_ARGUMENTS); - } - catch (const Exception & e) - { - std::string text = e.displayText(); - std::cerr << "Code: " << e.code() << ". " << text << std::endl; - std::cerr << "Table №" << i << std::endl << std::endl; - exit(e.code()); - } - } - /// Extract settings and limits from the options. #define EXTRACT_SETTING(TYPE, NAME, DEFAULT, DESCRIPTION) \ if (options.count(#NAME)) \ diff --git a/dbms/src/Server/DTTool/DTTool.cpp b/dbms/src/Server/DTTool/DTTool.cpp index d0334cb521a..772523a49a7 100644 --- a/dbms/src/Server/DTTool/DTTool.cpp +++ b/dbms/src/Server/DTTool/DTTool.cpp @@ -13,8 +13,8 @@ // limitations under the License. 
#include +#include -#include #include namespace bpo = boost::program_options; diff --git a/dbms/src/Server/DTTool/DTToolBench.cpp b/dbms/src/Server/DTTool/DTToolBench.cpp index 040a668708c..8dbbaefaf9e 100644 --- a/dbms/src/Server/DTTool/DTToolBench.cpp +++ b/dbms/src/Server/DTTool/DTToolBench.cpp @@ -30,9 +30,9 @@ #include #include #include +#include #include -#include #include #include #include diff --git a/dbms/src/Server/DTTool/DTToolInspect.cpp b/dbms/src/Server/DTTool/DTToolInspect.cpp index c906d15c36b..fdacd459d42 100644 --- a/dbms/src/Server/DTTool/DTToolInspect.cpp +++ b/dbms/src/Server/DTTool/DTToolInspect.cpp @@ -21,9 +21,9 @@ #include #include #include +#include #include -#include #include #include diff --git a/dbms/src/Server/DTTool/DTToolMigrate.cpp b/dbms/src/Server/DTTool/DTToolMigrate.cpp index 93fe1d9e7fc..ffb975f8175 100644 --- a/dbms/src/Server/DTTool/DTToolMigrate.cpp +++ b/dbms/src/Server/DTTool/DTToolMigrate.cpp @@ -19,9 +19,9 @@ #include #include #include +#include #include -#include #include namespace DTTool::Migrate diff --git a/dbms/src/Server/Server.cpp b/dbms/src/Server/Server.cpp index 4337a0dbc2c..88adbf40762 100644 --- a/dbms/src/Server/Server.cpp +++ b/dbms/src/Server/Server.cpp @@ -582,7 +582,7 @@ try /// Initialize users config reloader. auto users_config_reloader = UserConfig::parseSettings(config(), config_path, global_context, log); - /// Load global settings from default_profile and system_profile. + /// Load global settings from default_profile /// It internally depends on UserConfig::parseSettings. // TODO: Parse the settings from config file at the program beginning global_context->setDefaultProfiles(); diff --git a/dbms/src/Server/TCPHandler.cpp b/dbms/src/Server/TCPHandler.cpp index 2c52dd7810f..88994fe2f1a 100644 --- a/dbms/src/Server/TCPHandler.cpp +++ b/dbms/src/Server/TCPHandler.cpp @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include +#include #include #include #include @@ -685,22 +685,7 @@ bool TCPHandler::receiveData() { /// If there is an insert request, then the data should be written directly to `state.io.out`. /// Otherwise, we write the blocks in the temporary `external_table_name` table. - if (!state.need_receive_data_for_insert) - { - StoragePtr storage; - /// If such a table does not exist, create it. - if (!(storage = query_context.tryGetExternalTable(external_table_name))) - { - NamesAndTypesList columns = block.getNamesAndTypesList(); - storage = StorageMemory::create( - external_table_name, - ColumnsDescription{columns, NamesAndTypesList{}, NamesAndTypesList{}, ColumnDefaults{}}); - storage->startup(); - query_context.addExternalTable(external_table_name, storage); - } - /// The data will be written directly to the table. - state.io.out = storage->write(ASTPtr(), query_context.getSettingsRef()); - } + RUNTIME_CHECK_MSG(state.need_receive_data_for_insert, "Does not support write the blocks into external table"); if (block) state.io.out->write(block); return true; diff --git a/dbms/src/Storages/DeltaMerge/BitmapFilter/BitmapFilter.h b/dbms/src/Storages/DeltaMerge/BitmapFilter/BitmapFilter.h index eb5154bd231..bfd0c97d51d 100644 --- a/dbms/src/Storages/DeltaMerge/BitmapFilter/BitmapFilter.h +++ b/dbms/src/Storages/DeltaMerge/BitmapFilter/BitmapFilter.h @@ -32,6 +32,7 @@ class BitmapFilter void set(BlockInputStreamPtr & stream); // f[start, satrt+limit) = value void set(UInt32 start, UInt32 limit, bool value = true); + void set(std::span row_ids, const FilterPtr & f); // If return true, all data is match and do not fill the filter. bool get(IColumn::Filter & f, UInt32 start, UInt32 limit) const; // Caller should ensure n in [0, size). 
@@ -48,8 +49,6 @@ class BitmapFilter friend class BitmapFilterView; private: - void set(std::span row_ids, const FilterPtr & f); - IColumn::Filter filter; bool all_match; }; diff --git a/dbms/src/Storages/DeltaMerge/CMakeLists.txt b/dbms/src/Storages/DeltaMerge/CMakeLists.txt index d132c8e1b98..284107203b3 100644 --- a/dbms/src/Storages/DeltaMerge/CMakeLists.txt +++ b/dbms/src/Storages/DeltaMerge/CMakeLists.txt @@ -22,6 +22,8 @@ add_headers_and_sources(delta_merge .) add_headers_and_sources(delta_merge ./BitmapFilter) add_headers_and_sources(delta_merge ./Index) add_headers_and_sources(delta_merge ./Index/VectorIndex) +add_headers_and_sources(delta_merge ./Index/VectorIndex/Stream) +add_headers_and_sources(delta_merge ./Index/InvertedIndex) add_headers_and_sources(delta_merge ./Filter) add_headers_and_sources(delta_merge ./FilterParser) add_headers_and_sources(delta_merge ./File) diff --git a/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileInputStream.cpp b/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileInputStream.cpp new file mode 100644 index 00000000000..678370a2953 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileInputStream.cpp @@ -0,0 +1,70 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include + +namespace DB::DM +{ + +inline size_t ColumnFileInputStream::skipNextBlock() +{ + if (!reader) + return 0; + return reader->skipNextBlock(); +} + +Block ColumnFileInputStream::readWithFilter(const IColumn::Filter & filter) +{ + if (!reader) + return {}; + auto block = read(); + if (size_t passed_count = countBytesInFilter(filter); passed_count != block.rows()) + { + for (auto & col : block) + { + col.column = col.column->filter(filter, passed_count); + } + } + return block; +} + +inline Block ColumnFileInputStream::getHeader() const +{ + return toEmptyBlock(*col_defs); +} + +inline Block ColumnFileInputStream::read() +{ + if (!reader) + return {}; + return reader->readNextBlock(); +} + +ColumnFileInputStream::ColumnFileInputStream( + const DMContext & context_, + const ColumnFilePtr & column_file, + const IColumnFileDataProviderPtr & data_provider_, + const ColumnDefinesPtr & col_defs_, + ReadTag read_tag_) + : col_defs(col_defs_) + // Note that ColumnFileDelete does not have a reader, so that the reader will be nullptr. + , reader(column_file->getReader(context_, data_provider_, col_defs_, read_tag_)) +{ + RUNTIME_CHECK(col_defs != nullptr); +} + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileInputStream.h b/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileInputStream.h new file mode 100644 index 00000000000..a5103709d17 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileInputStream.h @@ -0,0 +1,76 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace DB::DM +{ + +class ColumnFile; +using ColumnFilePtr = std::shared_ptr; +class ColumnFileReader; +using ColumnFileReaderPtr = std::shared_ptr; + +class ColumnFileInputStream : public SkippableBlockInputStream +{ +public: // Implements SkippableBlockInputStream + bool getSkippedRows(size_t &) override { throw Exception("Not implemented", ErrorCodes::NOT_IMPLEMENTED); } + + size_t skipNextBlock() override; + + Block readWithFilter(const IColumn::Filter & filter) override; + +public: // Implements IBlockInputStream + String getName() const override { return "ColumnFile"; } + + Block getHeader() const override; + + // Note: The output block does not contain a start offset. + Block read() override; + +public: + static ColumnFileInputStreamPtr create( + const DMContext & context_, + const ColumnFilePtr & column_file, + const IColumnFileDataProviderPtr & data_provider_, + const ColumnDefinesPtr & col_defs_, + ReadTag read_tag_) + { + return std::make_shared(context_, column_file, data_provider_, col_defs_, read_tag_); + } + + explicit ColumnFileInputStream( + const DMContext & context_, + const ColumnFilePtr & column_file, + const IColumnFileDataProviderPtr & data_provider_, + const ColumnDefinesPtr & col_defs_, + ReadTag read_tag_); + +private: + // There could be possibly a lot of ColumnFiles. + // So we keep this struct as small as possible. 
+ + ColumnDefinesPtr col_defs; + ColumnFileReaderPtr reader; +}; + +} // namespace DB::DM diff --git a/libs/libcommon/src/tests/date_lut_init.cpp b/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileInputStream_fwd.h similarity index 71% rename from libs/libcommon/src/tests/date_lut_init.cpp rename to dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileInputStream_fwd.h index f9eb3a5a6b4..b9d7a698773 100644 --- a/libs/libcommon/src/tests/date_lut_init.cpp +++ b/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileInputStream_fwd.h @@ -1,4 +1,4 @@ -// Copyright 2023 PingCAP, Inc. +// Copyright 2025 PingCAP, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,11 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include +#pragma once -/// Позволяет проверить время инициализации DateLUT. -int main(int argc, char ** argv) +#include + +namespace DB::DM { - DateLUT::instance(); - return 0; -} + +class ColumnFileInputStream; + +using ColumnFileInputStreamPtr = std::shared_ptr; + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileSetReader.h b/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileSetReader.h index e0aa6cc4591..cea9c4a743b 100644 --- a/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileSetReader.h +++ b/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileSetReader.h @@ -25,7 +25,7 @@ namespace DB::DM class ColumnFileSetReader { friend class ColumnFileSetInputStream; - friend class ColumnFileSetWithVectorIndexInputStream; + friend class VectorIndexColumnFileSetInputStream; private: const DMContext & context; diff --git a/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileSetWithVectorIndexInputStream.cpp b/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileSetWithVectorIndexInputStream.cpp deleted file mode 100644 index 7df083f79e2..00000000000 --- 
a/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileSetWithVectorIndexInputStream.cpp +++ /dev/null @@ -1,234 +0,0 @@ -// Copyright 2024 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include - - -namespace DB::DM -{ - -SkippableBlockInputStreamPtr ColumnFileSetWithVectorIndexInputStream::tryBuild( - const DMContext & context, - const ColumnFileSetSnapshotPtr & delta_snap, - const ColumnDefinesPtr & col_defs, - const RowKeyRange & segment_range_, - const IColumnFileDataProviderPtr & data_provider, - const ANNQueryInfoPtr & ann_query_info, - const BitmapFilterPtr & bitmap_filter, - size_t offset, - ReadTag read_tag_) -{ - auto fallback = [&]() { - return std::make_shared(context, delta_snap, col_defs, segment_range_, read_tag_); - }; - - if (!bitmap_filter || !ann_query_info) - return fallback(); - - std::optional vec_cd; - auto rest_columns = std::make_shared(); - rest_columns->reserve(col_defs->size() - 1); - for (const auto & cd : *col_defs) - { - if (cd.id == ann_query_info->column_id()) - vec_cd.emplace(cd); - else - rest_columns->emplace_back(cd); - } - - // No vector index column is specified, fallback. 
- if (!vec_cd.has_value()) - return fallback(); - - bool has_vector_index = false; - for (const auto & file : delta_snap->getColumnFiles()) - { - if (auto * tiny_file = file->tryToTinyFile(); tiny_file && tiny_file->hasIndex(ann_query_info->index_id())) - { - has_vector_index = true; - break; - } - } - // No file has vector index, fallback. - if (!has_vector_index) - return fallback(); - - // All check passed. Let's read via vector index. - return std::make_shared( - context, - delta_snap, - col_defs, - segment_range_, - data_provider, - ann_query_info, - BitmapFilterView(bitmap_filter, offset, delta_snap->getRows()), - std::move(*vec_cd), - rest_columns, - read_tag_); -} - -Block ColumnFileSetWithVectorIndexInputStream::readOtherColumns() -{ - auto reset_column_file_reader = (*cur_column_file_reader)->createNewReader(rest_col_defs, ReadTag::Query); - Block block = reset_column_file_reader->readNextBlock(); - return block; -} - -void ColumnFileSetWithVectorIndexInputStream::toNextFile(size_t current_file_index, size_t current_file_rows) -{ - (*cur_column_file_reader).reset(); - ++cur_column_file_reader; - read_rows += current_file_rows; - tiny_readers[current_file_index].reset(); -} - -Block ColumnFileSetWithVectorIndexInputStream::read() -{ - load(); - - while (cur_column_file_reader != reader.column_file_readers.end()) - { - // Skip ColumnFileDeleteRange - if (*cur_column_file_reader == nullptr) - { - ++cur_column_file_reader; - continue; - } - auto current_file_index = std::distance(reader.column_file_readers.begin(), cur_column_file_reader); - // If file has index, we can read the column by vector index. 
- if (tiny_readers[current_file_index] != nullptr) - { - const auto file_rows = column_files[current_file_index]->getRows(); - auto selected_row_begin = std::lower_bound(sorted_results.cbegin(), sorted_results.cend(), read_rows); - auto selected_row_end = std::lower_bound(selected_row_begin, sorted_results.cend(), read_rows + file_rows); - size_t selected_rows = std::distance(selected_row_begin, selected_row_end); - // If all rows are filtered out, skip this file. - if (selected_rows == 0) - { - toNextFile(current_file_index, file_rows); - continue; - } - - // read vector type column by vector index - auto tiny_reader = tiny_readers[current_file_index]; - auto vec_column = vec_cd.type->createColumn(); - const std::span file_selected_rows{selected_row_begin, selected_row_end}; - tiny_reader->read(vec_column, file_selected_rows, /* rowid_start_offset= */ read_rows); - assert(vec_column->size() == selected_rows); - - // read other columns if needed - Block block; - if (!rest_col_defs->empty()) - { - block = readOtherColumns(); - filter.clear(); - filter.resize_fill(file_rows, 0); - for (const auto rowid : file_selected_rows) - filter[rowid - read_rows] = 1; - for (auto & col : block) - col.column = col.column->filter(filter, selected_rows); - - assert(block.rows() == selected_rows); - } - - auto index = header.getPositionByName(vec_cd.name); - block.insert(index, ColumnWithTypeAndName(std::move(vec_column), vec_cd.type, vec_cd.name)); - - // All rows in this ColumnFileTiny have been read. - block.setStartOffset(read_rows); - toNextFile(current_file_index, file_rows); - return block; - } - // If file does not have index, reader by cur_column_file_reader. 
- auto block = (*cur_column_file_reader)->readNextBlock(); - if (block) - { - block.setStartOffset(read_rows); - size_t rows = block.rows(); - filter = valid_rows.getRawSubFilter(read_rows, rows); - size_t passed_count = countBytesInFilter(filter); - for (auto & col : block) - col.column = col.column->filter(filter, passed_count); - read_rows += rows; - return block; - } - else - { - (*cur_column_file_reader).reset(); - ++cur_column_file_reader; - } - } - return {}; -} - -std::vector ColumnFileSetWithVectorIndexInputStream::load() -{ - if (loaded) - return {}; - - tiny_readers.reserve(column_files.size()); - UInt32 precedes_rows = 0; - std::vector search_results; - for (const auto & column_file : column_files) - { - if (auto * tiny_file = column_file->tryToTinyFile(); - tiny_file && tiny_file->hasIndex(ann_query_info->index_id())) - { - auto tiny_reader = std::make_shared( - *tiny_file, - data_provider, - ann_query_info, - valid_rows.createSubView(precedes_rows, tiny_file->getRows()), - vec_cd, - vec_index_cache); - auto sr = tiny_reader->load(); - for (auto & row : sr) - row.key += precedes_rows; - search_results.insert(search_results.end(), sr.begin(), sr.end()); - tiny_readers.push_back(tiny_reader); - // avoid virutal function call - precedes_rows += tiny_file->getRows(); - } - else - { - tiny_readers.push_back(nullptr); - precedes_rows += column_file->getRows(); - } - } - // Keep the top k minimum distances rows. - auto select_size - = search_results.size() > ann_query_info->top_k() ? ann_query_info->top_k() : search_results.size(); - auto top_k_end = search_results.begin() + select_size; - std::nth_element(search_results.begin(), top_k_end, search_results.end(), [](const auto & lhs, const auto & rhs) { - return lhs.distance < rhs.distance; - }); - search_results.resize(select_size); - // Sort by key again. 
- std::sort(search_results.begin(), search_results.end(), [](const auto & lhs, const auto & rhs) { - return lhs.key < rhs.key; - }); - - loaded = true; - return search_results; -} - -void ColumnFileSetWithVectorIndexInputStream::setSelectedRows(const std::span & selected_rows) -{ - sorted_results.reserve(selected_rows.size()); - std::copy(selected_rows.begin(), selected_rows.end(), std::back_inserter(sorted_results)); -} - -} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileSetWithVectorIndexInputStream.h b/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileSetWithVectorIndexInputStream.h deleted file mode 100644 index 35ede98832d..00000000000 --- a/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileSetWithVectorIndexInputStream.h +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright 2024 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include -#include -#include -#include -#include -#include -#include - - -namespace DB::DM -{ - -class ColumnFileSetWithVectorIndexInputStream : public VectorIndexBlockInputStream -{ -private: - ColumnFileSetReader reader; - - std::vector::iterator cur_column_file_reader; - size_t read_rows = 0; - - const IColumnFileDataProviderPtr data_provider; - const ANNQueryInfoPtr ann_query_info; - const BitmapFilterView valid_rows; - // Global vector index cache - const LocalIndexCachePtr vec_index_cache; - const ColumnDefine vec_cd; - const ColumnDefinesPtr rest_col_defs; - - // Set after load(). Top K search results in files with vector index. - std::vector sorted_results; - std::vector tiny_readers; - - const ColumnFiles & column_files; - - const Block header; - IColumn::Filter filter; - - bool loaded = false; - -public: - ColumnFileSetWithVectorIndexInputStream( - const DMContext & context_, - const ColumnFileSetSnapshotPtr & delta_snap_, - const ColumnDefinesPtr & col_defs_, - const RowKeyRange & segment_range_, - const IColumnFileDataProviderPtr & data_provider_, - const ANNQueryInfoPtr & ann_query_info_, - const BitmapFilterView && valid_rows_, - ColumnDefine && vec_cd_, - const ColumnDefinesPtr & rest_col_defs_, - ReadTag read_tag_) - : reader(context_, delta_snap_, col_defs_, segment_range_, read_tag_) - , data_provider(data_provider_) - , ann_query_info(ann_query_info_) - , valid_rows(std::move(valid_rows_)) - // ColumnFile vector index stores all data in memory, can not be evicted by system. 
- , vec_index_cache(context_.global_context.getHeavyLocalIndexCache()) - , vec_cd(std::move(vec_cd_)) - , rest_col_defs(rest_col_defs_) - , column_files(reader.snapshot->getColumnFiles()) - , header(toEmptyBlock(*(reader.col_defs))) - { - cur_column_file_reader = reader.column_file_readers.begin(); - } - - static SkippableBlockInputStreamPtr tryBuild( - const DMContext & context, - const ColumnFileSetSnapshotPtr & delta_snap, - const ColumnDefinesPtr & col_defs, - const RowKeyRange & segment_range_, - const IColumnFileDataProviderPtr & data_provider, - const ANNQueryInfoPtr & ann_query_info, - const BitmapFilterPtr & bitmap_filter, - size_t offset, - ReadTag read_tag_); - - String getName() const override { return "ColumnFileSetWithVectorIndex"; } - Block getHeader() const override { return header; } - - Block read() override; - - std::vector load() override; - - void setSelectedRows(const std::span & selected_rows) override; - -private: - Block readOtherColumns(); - - void toNextFile(size_t current_file_index, size_t current_file_rows); -}; - -} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileTiny.cpp b/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileTiny.cpp index a7f189876b9..a35b6d1038f 100644 --- a/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileTiny.cpp +++ b/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileTiny.cpp @@ -415,4 +415,21 @@ ColumnFileTiny::ColumnFileTiny( , file_provider(dm_context.global_context.getFileProvider()) {} +ColumnFileTiny::ColumnFileTiny( + const ColumnFileSchemaPtr & schema_, + UInt64 rows_, + UInt64 bytes_, + PageIdU64 data_page_id_, + KeyspaceID keyspace_id_, + const FileProviderPtr & file_provider_, + const IndexInfosPtr & index_infos_) + : schema(schema_) + , rows(rows_) + , bytes(bytes_) + , data_page_id(data_page_id_) + , index_infos(index_infos_) + , keyspace_id(keyspace_id_) + , file_provider(file_provider_) +{} + } // namespace DB::DM diff --git 
a/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileTiny.h b/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileTiny.h index 9eb4a099c66..11806463e25 100644 --- a/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileTiny.h +++ b/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileTiny.h @@ -41,7 +41,7 @@ class ColumnFileTiny : public ColumnFilePersisted friend struct Remote::Serializer; using IndexInfos = std::vector; - using IndexInfosPtr = std::shared_ptr; + using IndexInfosPtr = std::shared_ptr; private: ColumnFileSchemaPtr schema; @@ -50,7 +50,7 @@ class ColumnFileTiny : public ColumnFilePersisted UInt64 bytes = 0; /// The id of data page which stores the data of this pack. - PageIdU64 data_page_id; + const PageIdU64 data_page_id; /// HACK: Currently this field is only available when ColumnFileTiny is restored from remote proto. /// It is not available when ColumnFileTiny is constructed or restored locally. @@ -58,7 +58,7 @@ class ColumnFileTiny : public ColumnFilePersisted UInt64 data_page_size = 0; /// The index information of this file. - IndexInfosPtr index_infos; + const IndexInfosPtr index_infos; /// The id of the keyspace which this ColumnFileTiny belongs to. 
const KeyspaceID keyspace_id; @@ -77,36 +77,61 @@ class ColumnFileTiny : public ColumnFilePersisted const DMContext & dm_context, const IndexInfosPtr & index_infos_ = nullptr); + ColumnFileTiny( + const ColumnFileSchemaPtr & schema_, + UInt64 rows_, + UInt64 bytes_, + PageIdU64 data_page_id_, + KeyspaceID keyspace_id_, + const FileProviderPtr & file_provider_, + const IndexInfosPtr & index_infos_); + Type getType() const override { return Type::TINY_FILE; } size_t getRows() const override { return rows; } size_t getBytes() const override { return bytes; } IndexInfosPtr getIndexInfos() const { return index_infos; } - bool hasIndex(Int64 index_id) const + + bool hasIndex(Int64 index_id) const { return findIndexInfo(index_id) != nullptr; } + + const dtpb::ColumnFileIndexInfo * findIndexInfo(Int64 index_id) const { if (!index_infos) - return false; - return std::any_of(index_infos->cbegin(), index_infos->cend(), [index_id](const auto & info) { - return info.index_props().index_id() == index_id; - }); + return nullptr; + const auto it = std::find_if( // + index_infos->cbegin(), + index_infos->cend(), + [index_id](const auto & info) { return info.index_props().index_id() == index_id; }); + if (it == index_infos->cend()) + return nullptr; + return &*it; } ColumnFileSchemaPtr getSchema() const { return schema; } ColumnFileTinyPtr cloneWith(PageIdU64 new_data_page_id) { - auto new_tiny_file = std::make_shared(*this); - new_tiny_file->data_page_id = new_data_page_id; - return new_tiny_file; + return std::make_shared( + schema, + rows, + bytes, + new_data_page_id, + keyspace_id, + file_provider, + index_infos); } - ColumnFileTinyPtr cloneWith(PageIdU64 new_data_page_id, const IndexInfosPtr & index_infos_) const + ColumnFileTinyPtr cloneWith(PageIdU64 new_data_page_id, const IndexInfosPtr & new_index_infos) const { - auto new_tiny_file = std::make_shared(*this); - new_tiny_file->data_page_id = new_data_page_id; - new_tiny_file->index_infos = index_infos_; - return 
new_tiny_file; + return std::make_shared( + schema, + rows, + bytes, + new_data_page_id, + keyspace_id, + file_provider, + new_index_infos); } ColumnFileReaderPtr getReader( diff --git a/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileTinyLocalIndexWriter.cpp b/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileTinyLocalIndexWriter.cpp index d1854eceb01..5316a5ffe4c 100644 --- a/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileTinyLocalIndexWriter.cpp +++ b/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileTinyLocalIndexWriter.cpp @@ -16,8 +16,8 @@ #include #include #include -#include -#include +#include +#include namespace DB::ErrorCodes @@ -94,7 +94,7 @@ ColumnFileTinyPtr ColumnFileTinyLocalIndexWriter::buildIndexForFile( struct IndexToBuild { LocalIndexInfo info; - VectorIndexWriterInMemoryPtr builder_vector; + LocalIndexWriterInMemoryPtr index_writer; }; std::unordered_map> index_builders; @@ -107,7 +107,7 @@ ColumnFileTinyPtr ColumnFileTinyLocalIndexWriter::buildIndexForFile( RUNTIME_CHECK(index_info.def_vector_index != nullptr); index_builders[index_info.column_id].emplace_back(IndexToBuild{ .info = index_info, - .builder_vector = {}, + .index_writer = {}, }); } @@ -125,17 +125,8 @@ ColumnFileTinyPtr ColumnFileTinyLocalIndexWriter::buildIndexForFile( file->getDataPageId()); for (auto & index : indexes) - { - switch (index.info.kind) - { - case TiDB::ColumnarIndexKind::Vector: - index.builder_vector = VectorIndexWriterInMemory::create(index.info.def_vector_index); - break; - default: - RUNTIME_CHECK_MSG(false, "Unsupported index kind: {}", magic_enum::enum_name(index.info.kind)); - break; - } - } + index.index_writer = LocalIndexWriter::createInMemory(index.info); + read_columns->push_back(*cd_iter); } @@ -171,16 +162,8 @@ ColumnFileTinyPtr ColumnFileTinyLocalIndexWriter::buildIndexForFile( const auto & col = col_with_type_and_name.column; for (const auto & index : index_builders[read_columns->at(col_idx).id]) { - switch (index.info.kind) - { - case 
TiDB::ColumnarIndexKind::Vector: - RUNTIME_CHECK(index.builder_vector); - index.builder_vector->addBlock(*col, del_mark, should_proceed); - break; - default: - RUNTIME_CHECK_MSG(false, "Unsupported index kind: {}", magic_enum::enum_name(index.info.kind)); - break; - } + RUNTIME_CHECK(index.index_writer); + index.index_writer->addBlock(*col, del_mark, should_proceed); } } } @@ -192,45 +175,24 @@ ColumnFileTinyPtr ColumnFileTinyLocalIndexWriter::buildIndexForFile( const auto & cd = read_columns->at(col_idx); for (const auto & index : index_builders[cd.id]) { - switch (index.info.kind) - { - case TiDB::ColumnarIndexKind::Vector: - { - RUNTIME_CHECK(index.builder_vector); - auto index_page_id = options.storage_pool->newLogPageId(); - MemoryWriteBuffer write_buf; - CompressedWriteBuffer compressed(write_buf); - index.builder_vector->finalize(compressed); - compressed.next(); - auto data_size = write_buf.count(); - auto buf = write_buf.tryGetReadBuffer(); - // ColumnFileDataProviderRNLocalPageCache currently does not support read data with fields - options.wbs.log.putPage(index_page_id, 0, buf, data_size, {data_size}); - - auto idx_info = dtpb::ColumnFileIndexInfo{}; - idx_info.set_index_page_id(index_page_id); - auto * idx_props = idx_info.mutable_index_props(); - idx_props->set_kind(dtpb::IndexFileKind::VECTOR_INDEX); - idx_props->set_index_id(index.info.index_id); - idx_props->set_file_size(data_size); - auto * vector_index = idx_props->mutable_vector_index(); - vector_index->set_format_version(0); - vector_index->set_dimensions(index.info.def_vector_index->dimension); - vector_index->set_distance_metric( - tipb::VectorDistanceMetric_Name(index.info.def_vector_index->distance_metric)); - index_infos->emplace_back(std::move(idx_info)); - - break; - } - default: - RUNTIME_CHECK_MSG(false, "Unsupported index kind: {}", magic_enum::enum_name(index.info.kind)); - break; - } + RUNTIME_CHECK(index.index_writer); + auto index_page_id = options.storage_pool->newLogPageId(); + 
MemoryWriteBuffer write_buf; + CompressedWriteBuffer compressed(write_buf); + dtpb::ColumnFileIndexInfo pb_cf_idx; + pb_cf_idx.set_index_page_id(index_page_id); + auto idx_info = index.index_writer->finalize(compressed, [&write_buf] { return write_buf.count(); }); + pb_cf_idx.mutable_index_props()->Swap(&idx_info); + auto data_size = write_buf.count(); + auto buf = write_buf.tryGetReadBuffer(); + // ColumnFileDataProviderRNLocalPageCache currently does not support read data withiout fields + options.wbs.log.putPage(index_page_id, 0, buf, data_size, {data_size}); + index_infos->emplace_back(std::move(pb_cf_idx)); } } - if (file->index_infos) - file->index_infos->insert(file->index_infos->end(), index_infos->begin(), index_infos->end()); + if (const auto & file_index_info = file->getIndexInfos(); file_index_info) + index_infos->insert(index_infos->end(), file_index_info->begin(), file_index_info->end()); options.wbs.writeLogAndData(); // Note: The id of the file cannot be changed, otherwise minor compaction will fail. diff --git a/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileTinyVectorIndexReader.cpp b/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileTinyVectorIndexReader.cpp deleted file mode 100644 index 37b419ae462..00000000000 --- a/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileTinyVectorIndexReader.cpp +++ /dev/null @@ -1,142 +0,0 @@ -// Copyright 2024 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include -#include -#include -#include -#include -#include -#include -#include - - -namespace DB::DM -{ - -void ColumnFileTinyVectorIndexReader::read( - MutableColumnPtr & vec_column, - const std::span & read_rowids, - size_t rowid_start_offset) -{ - RUNTIME_CHECK(loaded); - - Stopwatch watch; - vec_column->reserve(read_rowids.size()); - std::vector value; - for (const auto & rowid : read_rowids) - { - // Each ColomnFileTiny has its own vector index, rowid_start_offset is the offset of the ColmnFilePersistSet. - vec_index->get(rowid - rowid_start_offset, value); - vec_column->insertData(reinterpret_cast(value.data()), value.size() * sizeof(Float32)); - } - - perf_stat.returned_rows = read_rowids.size(); - perf_stat.read_vec_column_seconds = watch.elapsedSeconds(); -} - -std::vector ColumnFileTinyVectorIndexReader::load() -{ - if (loaded) - return {}; - - Stopwatch watch; - - loadVectorIndex(); - auto search_results = loadVectorSearchResult(); - - perf_stat.load_vec_index_and_results_seconds = watch.elapsedSeconds(); - - loaded = true; - return search_results; -} - -void ColumnFileTinyVectorIndexReader::loadVectorIndex() -{ - const auto & index_infos = tiny_file.index_infos; - if (!index_infos || index_infos->empty()) - return; - auto index_id = ann_query_info->index_id(); - const auto index_info_iter - = std::find_if(index_infos->cbegin(), index_infos->cend(), [index_id](const auto & info) { - return info.index_props().index_id() == index_id; - }); - if (index_info_iter == index_infos->cend()) - return; - if (!index_info_iter->index_props().has_vector_index()) - return; - auto index_page_id = index_info_iter->index_page_id(); - auto load_from_page_storage = [&]() { - perf_stat.load_from_cache = false; - std::vector index_fields = {0}; - auto index_page = data_provider->readTinyData(index_page_id, index_fields); - ReadBufferFromOwnString read_buf(index_page.data); - CompressedReadBuffer compressed(read_buf); - return 
VectorIndexReader::createFromMemory(index_info_iter->index_props().vector_index(), compressed); - }; - if (local_index_cache) - { - const auto key = fmt::format("{}{}", LocalIndexCache::COLUMNFILETINY_INDEX_NAME_PREFIX, index_page_id); - auto local_index = local_index_cache->getOrSet(key, load_from_page_storage); - vec_index = std::dynamic_pointer_cast(local_index); - } - else - vec_index = load_from_page_storage(); -} - -ColumnFileTinyVectorIndexReader::~ColumnFileTinyVectorIndexReader() -{ - LOG_DEBUG( - log, - "Finish vector search over column tiny_{}/{}(cid={}, rows={}){} cached, cost_[search/read]={:.3f}s/{:.3f}s " - "top_k_[query/visited/discarded/result]={}/{}/{}/{} ", - tiny_file.getDataPageId(), - vec_cd.name, - vec_cd.id, - tiny_file.getRows(), - perf_stat.load_from_cache ? "" : " not", - - perf_stat.load_vec_index_and_results_seconds, - perf_stat.read_vec_column_seconds, - - ann_query_info->top_k(), - perf_stat.visited_nodes, // Visited nodes will be larger than query_top_k when there are MVCC rows - perf_stat.discarded_nodes, // How many nodes are skipped by MVCC - perf_stat.returned_rows); -} - -std::vector ColumnFileTinyVectorIndexReader::loadVectorSearchResult() -{ - auto perf_begin = PerfContext::vector_search; - RUNTIME_CHECK(valid_rows.size() == tiny_file.getRows(), valid_rows.size(), tiny_file.getRows()); - - auto search_results = vec_index->search(ann_query_info, valid_rows); - // Sort by key - std::sort(search_results.begin(), search_results.end(), [](const auto & lhs, const auto & rhs) { - return lhs.key < rhs.key; - }); - // results must not contain duplicates. Usually there should be no duplicates. 
- search_results.erase( - std::unique( - search_results.begin(), - search_results.end(), - [](const auto & lhs, const auto & rhs) { return lhs.key == rhs.key; }), - search_results.end()); - - perf_stat.discarded_nodes = PerfContext::vector_search.discarded_nodes - perf_begin.discarded_nodes; - perf_stat.visited_nodes = PerfContext::vector_search.visited_nodes - perf_begin.visited_nodes; - return search_results; -} - -} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileTinyVectorIndexReader.h b/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileTinyVectorIndexReader.h deleted file mode 100644 index 42f30131854..00000000000 --- a/dbms/src/Storages/DeltaMerge/ColumnFile/ColumnFileTinyVectorIndexReader.h +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright 2024 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include -#include - - -namespace DB::DM -{ - -class ColumnFileTinyVectorIndexReader -{ -private: - const ColumnFileTiny & tiny_file; - const IColumnFileDataProviderPtr data_provider; - - const ANNQueryInfoPtr ann_query_info; - // Set after load(). - VectorIndexReaderPtr vec_index; - const BitmapFilterView valid_rows; - // Note: ColumnDefine comes from read path does not have vector_index fields. 
- const ColumnDefine vec_cd; - // Global local index cache - const LocalIndexCachePtr local_index_cache; - LoggerPtr log; - - // Performance statistics - struct PerfStat - { - double load_vec_index_and_results_seconds = 0; - double read_vec_column_seconds = 0; - size_t discarded_nodes = 0; - size_t visited_nodes = 0; - size_t returned_rows = 0; - // Whether the vector index is loaded from cache. - bool load_from_cache = true; - }; - PerfStat perf_stat; - - // Whether the vector index and search results are loaded. - bool loaded = false; - -public: - ColumnFileTinyVectorIndexReader( - const ColumnFileTiny & tiny_file_, - const IColumnFileDataProviderPtr & data_provider_, - const ANNQueryInfoPtr & ann_query_info_, - const BitmapFilterView && valid_rows_, - const ColumnDefine & vec_cd_, - const LocalIndexCachePtr & local_index_cache_) - : tiny_file(tiny_file_) - , data_provider(data_provider_) - , ann_query_info(ann_query_info_) - , valid_rows(std::move(valid_rows_)) - , vec_cd(vec_cd_) - , local_index_cache(local_index_cache_) - , log(Logger::get()) - {} - - ~ColumnFileTinyVectorIndexReader(); - - // Read vector column data with the specified rowids. - void read( - MutableColumnPtr & vec_column, - const std::span & read_rowids, - size_t rowid_start_offset); - - // Load vector index and search results. - // Return the rowids of the selected rows. 
- std::vector load(); - -private: - void loadVectorIndex(); - std::vector loadVectorSearchResult(); -}; - -using ColumnFileTinyVectorIndexReaderPtr = std::shared_ptr; - -} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/ConcatSkippableBlockInputStream.cpp b/dbms/src/Storages/DeltaMerge/ConcatSkippableBlockInputStream.cpp index bbb75f2d2ac..9985789f977 100644 --- a/dbms/src/Storages/DeltaMerge/ConcatSkippableBlockInputStream.cpp +++ b/dbms/src/Storages/DeltaMerge/ConcatSkippableBlockInputStream.cpp @@ -16,28 +16,12 @@ #include #include -#include - namespace DB::DM { template ConcatSkippableBlockInputStream::ConcatSkippableBlockInputStream( - SkippableBlockInputStreams inputs_, - const ScanContextPtr & scan_context_) - : rows(inputs_.size(), 0) - , precede_stream_rows(0) - , scan_context(scan_context_) - , lac_bytes_collector(scan_context_ ? scan_context_->resource_group_name : "") -{ - assert(inputs_.size() == 1); // otherwise the `rows` is not correct - children.insert(children.end(), inputs_.begin(), inputs_.end()); - current_stream = children.begin(); -} - -template -ConcatSkippableBlockInputStream::ConcatSkippableBlockInputStream( - SkippableBlockInputStreams inputs_, + SkippableBlockInputStreams && inputs_, std::vector && rows_, const ScanContextPtr & scan_context_) : rows(std::move(rows_)) @@ -190,114 +174,4 @@ void ConcatSkippableBlockInputStream::addReadBytes(UInt64 bytes) template class ConcatSkippableBlockInputStream; template class ConcatSkippableBlockInputStream; -void ConcatVectorIndexBlockInputStream::load() -{ - if (loaded || topk == 0) - return; - - UInt32 precedes_rows = 0; - // otherwise the `row.key` of the search result is not correct - assert(stream->children.size() == index_streams.size()); - std::vector search_results; - for (size_t i = 0; i < stream->children.size(); ++i) - { - if (auto * index_stream = index_streams[i]; index_stream) - { - auto sr = index_stream->load(); - for (auto & row : sr) - row.key += precedes_rows; - 
search_results.insert(search_results.end(), sr.begin(), sr.end()); - } - precedes_rows += stream->rows[i]; - } - - // Keep the top k minimum distances rows. - const auto select_size = std::min(search_results.size(), topk); - auto top_k_end = search_results.begin() + select_size; - std::nth_element(search_results.begin(), top_k_end, search_results.end(), [](const auto & lhs, const auto & rhs) { - return lhs.distance < rhs.distance; - }); - std::vector selected_rows(select_size); - for (size_t i = 0; i < select_size; ++i) - selected_rows[i] = search_results[i].key; - // Sort by key again. - std::sort(selected_rows.begin(), selected_rows.end()); - - precedes_rows = 0; - auto sr_it = selected_rows.begin(); - for (size_t i = 0; i < stream->children.size(); ++i) - { - auto begin = std::lower_bound(sr_it, selected_rows.end(), precedes_rows); - auto end = std::lower_bound(begin, selected_rows.end(), precedes_rows + stream->rows[i]); - // Convert to local offset. - for (auto it = begin; it != end; ++it) - *it -= precedes_rows; - if (auto * index_stream = index_streams[i]; index_stream) - index_stream->setSelectedRows({begin, end}); - else - RUNTIME_CHECK(begin == end); - precedes_rows += stream->rows[i]; - sr_it = end; - } - - loaded = true; -} - -Block ConcatVectorIndexBlockInputStream::read() -{ - load(); - auto block = stream->read(); - if (!block) - return block; - - // The block read from `VectorIndexBlockInputStream` only return the selected rows. Return it directly. - // For streams which are not `VectorIndexBlockInputStream`, the block should be filtered by bitmap. 
- if (auto index = std::distance(stream->children.begin(), stream->current_stream); !index_streams[index]) - { - filter.resize(block.rows()); - if (bool all_match = bitmap_filter->get(filter, block.startOffset(), block.rows()); all_match) - return block; - - size_t passed_count = countBytesInFilter(filter); - for (auto & col : block) - { - col.column = col.column->filter(filter, passed_count); - } - } - - return block; -} - -std::tuple ConcatVectorIndexBlockInputStream::build( - const BitmapFilterPtr & bitmap_filter, - std::shared_ptr> stream, - const ANNQueryInfoPtr & ann_query_info) -{ - assert(ann_query_info != nullptr); - bool has_vector_index_stream = false; - std::vector index_streams; - index_streams.reserve(stream->children.size()); - for (const auto & sub_stream : stream->children) - { - if (auto * index_stream = dynamic_cast(sub_stream.get()); index_stream) - { - has_vector_index_stream = true; - index_streams.push_back(index_stream); - continue; - } - index_streams.push_back(nullptr); - } - if (!has_vector_index_stream) - return {stream, false}; - - return { - std::make_shared( - bitmap_filter, - stream, - std::move(index_streams), - ann_query_info->top_k()), - true, - }; -} - } // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/ConcatSkippableBlockInputStream.h b/dbms/src/Storages/DeltaMerge/ConcatSkippableBlockInputStream.h index 90442bfebdd..b6359e18d8a 100644 --- a/dbms/src/Storages/DeltaMerge/ConcatSkippableBlockInputStream.h +++ b/dbms/src/Storages/DeltaMerge/ConcatSkippableBlockInputStream.h @@ -15,9 +15,9 @@ #pragma once #include +#include #include #include -#include namespace DB::DM @@ -27,10 +27,19 @@ template class ConcatSkippableBlockInputStream : public SkippableBlockInputStream { public: - ConcatSkippableBlockInputStream(SkippableBlockInputStreams inputs_, const ScanContextPtr & scan_context_); + static auto create( + SkippableBlockInputStreams && inputs_, + std::vector && rows_, + const ScanContextPtr & scan_context_) + { + 
return std::make_shared>( + std::move(inputs_), + std::move(rows_), + scan_context_); + } ConcatSkippableBlockInputStream( - SkippableBlockInputStreams inputs_, + SkippableBlockInputStreams && inputs_, std::vector && rows_, const ScanContextPtr & scan_context_); @@ -49,7 +58,7 @@ class ConcatSkippableBlockInputStream : public SkippableBlockInputStream Block read() override; private: - friend class ConcatVectorIndexBlockInputStream; + friend class VectorIndexInputStream; ColumnPtr createSegmentRowIdCol(UInt64 start, UInt64 limit); void addReadBytes(UInt64 bytes); @@ -61,53 +70,4 @@ class ConcatSkippableBlockInputStream : public SkippableBlockInputStream LACBytesCollector lac_bytes_collector; }; -class ConcatVectorIndexBlockInputStream : public SkippableBlockInputStream -{ -public: - // Only return the rows that match `bitmap_filter_` - ConcatVectorIndexBlockInputStream( - const BitmapFilterPtr & bitmap_filter_, - std::shared_ptr> stream, - std::vector && index_streams, - UInt32 topk_) - : stream(std::move(stream)) - , index_streams(std::move(index_streams)) - , topk(topk_) - , bitmap_filter(bitmap_filter_) - {} - - // Returns - static std::tuple build( - const BitmapFilterPtr & bitmap_filter, - std::shared_ptr> stream, - const ANNQueryInfoPtr & ann_query_info); - - String getName() const override { return "ConcatVectorIndex"; } - - Block getHeader() const override { return stream->getHeader(); } - - bool getSkippedRows(size_t &) override { throw Exception("Not implemented", ErrorCodes::NOT_IMPLEMENTED); } - - size_t skipNextBlock() override { throw Exception("Not implemented", ErrorCodes::NOT_IMPLEMENTED); } - - Block readWithFilter(const IColumn::Filter &) override - { - throw Exception("Not implemented", ErrorCodes::NOT_IMPLEMENTED); - } - - Block read() override; - -private: - void load(); - - std::shared_ptr> stream; - // Pointers to stream's children, nullptr if the child is not a VectorIndexBlockInputStream. 
- std::vector index_streams; - UInt32 topk = 0; - bool loaded = false; - - BitmapFilterPtr bitmap_filter; - IColumn::Filter filter; // reuse the memory allocated among all `read` -}; - } // namespace DB::DM diff --git a/dbms/src/Storages/KVStore/tests/region_helper.h b/dbms/src/Storages/DeltaMerge/ConcatSkippableBlockInputStream_fwd.h similarity index 61% rename from dbms/src/Storages/KVStore/tests/region_helper.h rename to dbms/src/Storages/DeltaMerge/ConcatSkippableBlockInputStream_fwd.h index e3dcc6fb9e3..4bc09c80e2c 100644 --- a/dbms/src/Storages/KVStore/tests/region_helper.h +++ b/dbms/src/Storages/DeltaMerge/ConcatSkippableBlockInputStream_fwd.h @@ -1,4 +1,4 @@ -// Copyright 2023 PingCAP, Inc. +// Copyright 2025 PingCAP, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,16 +13,16 @@ // limitations under the License. #pragma once -#include -#include -#include +#include -namespace DB::tests +namespace DB::DM { -using DB::RegionBench::createPeer; -using DB::RegionBench::createRegionInfo; -using DB::RegionBench::createRegionMeta; -using DB::RegionBench::DebugRegion; -using DB::RegionBench::makeRegion; -} // namespace DB::tests + +template +class ConcatSkippableBlockInputStream; + +template +using ConcatSkippableBlockInputStreamPtr = std::shared_ptr>; + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Decode/SSTFilesToBlockInputStream.h b/dbms/src/Storages/DeltaMerge/Decode/SSTFilesToBlockInputStream.h index f5d7749b518..1e55684f3ca 100644 --- a/dbms/src/Storages/DeltaMerge/Decode/SSTFilesToBlockInputStream.h +++ b/dbms/src/Storages/DeltaMerge/Decode/SSTFilesToBlockInputStream.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -27,8 +28,6 @@ namespace DB { class TMTContext; -class Region; -using RegionPtr = std::shared_ptr; struct SSTViewVec; struct TiFlashRaftProxyHelper; diff --git 
a/dbms/src/Storages/DeltaMerge/Decode/SSTFilesToDTFilesOutputStream.h b/dbms/src/Storages/DeltaMerge/Decode/SSTFilesToDTFilesOutputStream.h index 6f167c15ed6..477e6053382 100644 --- a/dbms/src/Storages/DeltaMerge/Decode/SSTFilesToDTFilesOutputStream.h +++ b/dbms/src/Storages/DeltaMerge/Decode/SSTFilesToDTFilesOutputStream.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -27,8 +28,6 @@ namespace DB { class TMTContext; -class Region; -using RegionPtr = std::shared_ptr; struct SSTViewVec; struct TiFlashRaftProxyHelper; diff --git a/dbms/src/Storages/DeltaMerge/Delta/ColumnFilePersistedSet.h b/dbms/src/Storages/DeltaMerge/Delta/ColumnFilePersistedSet.h index 784060f6ecc..d3290cd67b7 100644 --- a/dbms/src/Storages/DeltaMerge/Delta/ColumnFilePersistedSet.h +++ b/dbms/src/Storages/DeltaMerge/Delta/ColumnFilePersistedSet.h @@ -49,7 +49,7 @@ class ColumnFilePersistedSet , private boost::noncopyable { private: - PageIdU64 metadata_id; + const PageIdU64 metadata_id; ColumnFilePersisteds persisted_files; // TODO: check the proper memory_order when use this atomic variable std::atomic persisted_files_count = 0; diff --git a/dbms/src/Storages/DeltaMerge/Delta/DeltaValueSpace.cpp b/dbms/src/Storages/DeltaMerge/Delta/DeltaValueSpace.cpp index 08cdd89f55a..7ba6c89bb5c 100644 --- a/dbms/src/Storages/DeltaMerge/Delta/DeltaValueSpace.cpp +++ b/dbms/src/Storages/DeltaMerge/Delta/DeltaValueSpace.cpp @@ -336,18 +336,19 @@ bool DeltaValueSpace::ingestColumnFiles( bool DeltaValueSpace::flush(DMContext & context) { + String simple_info = simpleInfo(); bool v = false; if (!is_flushing.compare_exchange_strong(v, true)) { // other thread is flushing, just return. 
- LOG_DEBUG(log, "Flush stop because other thread is flushing, delta={}", simpleInfo()); + LOG_DEBUG(log, "Flush stop because other thread is flushing, delta={}", simple_info); return false; } SCOPE_EXIT({ bool v = true; if (!is_flushing.compare_exchange_strong(v, false)) throw Exception( - fmt::format("Delta is expected to be flushing, delta={}", simpleInfo()), + fmt::format("Delta is expected to be flushing, delta={}", simple_info), ErrorCodes::LOGICAL_ERROR); }); @@ -365,7 +366,7 @@ bool DeltaValueSpace::flush(DMContext & context) std::scoped_lock lock(mutex); if (abandoned.load(std::memory_order_relaxed)) { - LOG_DEBUG(log, "Flush stop because abandoned, delta={}", simpleInfo()); + LOG_DEBUG(log, "Flush stop because abandoned, delta={}", simple_info); return false; } flush_task = mem_table_set->buildFlushTask( @@ -379,7 +380,7 @@ bool DeltaValueSpace::flush(DMContext & context) // No update, return successfully. if (!flush_task) { - LOG_DEBUG(log, "Flush cancel because nothing to flush, delta={}", simpleInfo()); + LOG_DEBUG(log, "Flush cancel because nothing to flush, delta={}", simple_info); return true; } @@ -388,9 +389,9 @@ bool DeltaValueSpace::flush(DMContext & context) DeltaIndexPtr new_delta_index; if (!delta_index_updates.empty()) { - LOG_DEBUG(log, "Update index start, delta={}", simpleInfo()); + LOG_DEBUG(log, "Update index start, delta={}", simple_info); new_delta_index = cur_delta_index->cloneWithUpdates(delta_index_updates); - LOG_DEBUG(log, "Update index done, delta={}", simpleInfo()); + LOG_DEBUG(log, "Update index done, delta={}", simple_info); } GET_METRIC(tiflash_storage_subtask_throughput_bytes, type_delta_flush).Increment(flush_task->getFlushBytes()); GET_METRIC(tiflash_storage_subtask_throughput_rows, type_delta_flush).Increment(flush_task->getFlushRows()); @@ -404,14 +405,14 @@ bool DeltaValueSpace::flush(DMContext & context) { // Delete written data. 
wbs.setRollback(); - LOG_DEBUG(log, "Flush stop because abandoned, delta={}", simpleInfo()); + LOG_DEBUG(log, "Flush stop because abandoned, delta={}", simple_info); return false; } if (!flush_task->commit(persisted_file_set, wbs)) { wbs.rollbackWrittenLogAndData(); - LOG_DEBUG(log, "Flush stop because structure got updated, delta={}", simpleInfo()); + LOG_DEBUG(log, "Flush stop because structure got updated, delta={}", simple_info); return false; } diff --git a/dbms/src/Storages/DeltaMerge/DeltaMergeStore.cpp b/dbms/src/Storages/DeltaMerge/DeltaMergeStore.cpp index 469c96c8a37..d7c45c7744b 100644 --- a/dbms/src/Storages/DeltaMerge/DeltaMergeStore.cpp +++ b/dbms/src/Storages/DeltaMerge/DeltaMergeStore.cpp @@ -206,6 +206,22 @@ ColumnDefinesPtr generateStoreColumns(const ColumnDefines & table_columns, bool } return columns; } + +void convertStringTypeToDefault(DataTypePtr & type) +{ + if (removeNullable(type)->getTypeId() != TypeIndex::String) + return; + if (type->isNullable()) + type = DataTypeFactory::instance().getOrSet(DataTypeString::getNullableDefaultName()); + else + type = DataTypeFactory::instance().getOrSet(DataTypeString::getDefaultName()); +} + +void convertStringTypeToDefault(ColumnDefines & cds) +{ + for (auto & col : cds) + convertStringTypeToDefault(col.type); +} } // namespace DeltaMergeStore::Settings DeltaMergeStore::EMPTY_SETTINGS @@ -262,18 +278,17 @@ DeltaMergeStore::DeltaMergeStore( // Should be done before any background task setup. 
restoreStableFiles(); - original_table_columns.emplace_back(original_table_handle_define); - original_table_columns.emplace_back(getVersionColumnDefine()); - original_table_columns.emplace_back(getTagColumnDefine()); + ColumnDefines tmp_table_columns; + tmp_table_columns.emplace_back(original_table_handle_define); + tmp_table_columns.emplace_back(getVersionColumnDefine()); + tmp_table_columns.emplace_back(getTagColumnDefine()); for (const auto & col : columns) { if (col.id != original_table_handle_define.id && col.id != MutSup::version_col_id && col.id != MutSup::delmark_col_id) - original_table_columns.emplace_back(col); + tmp_table_columns.emplace_back(col); } - - original_table_header = std::make_shared(toEmptyBlock(original_table_columns)); - store_columns = generateStoreColumns(original_table_columns, is_common_handle); + updateColumnDefines(std::move(tmp_table_columns)); auto dm_context = newDMContext(db_context, db_context.getSettingsRef()); PageStorageRunMode page_storage_run_mode; @@ -2025,12 +2040,7 @@ void DeltaMergeStore::applySchemaChanges(TiDB::TableInfo & table_info) replica_exist.store(false); } - auto new_store_columns = generateStoreColumns(new_original_table_columns, is_common_handle); - - original_table_columns.swap(new_original_table_columns); - store_columns.swap(new_store_columns); - - std::atomic_store(&original_table_header, std::make_shared(toEmptyBlock(original_table_columns))); + updateColumnDefines(std::move(new_original_table_columns)); // release the lock because `applyLocalIndexChange` will try to acquire the lock // and generate tasks on segments @@ -2308,5 +2318,15 @@ void DeltaMergeStore::createFirstSegment(DM::DMContext & dm_context) addSegment(lock, first_segment); } +void DeltaMergeStore::updateColumnDefines(ColumnDefines && tmp_columns) +{ + // Tables created before the new string serialization format takes effect will + // not be automatically converted to the new type during restoration. 
+ // Here, we force it to align with default string type. + convertStringTypeToDefault(tmp_columns); + original_table_columns = std::move(tmp_columns); + store_columns = generateStoreColumns(original_table_columns, is_common_handle); + std::atomic_store(&original_table_header, std::make_shared(toEmptyBlock(original_table_columns))); +} } // namespace DM } // namespace DB diff --git a/dbms/src/Storages/DeltaMerge/DeltaMergeStore.h b/dbms/src/Storages/DeltaMerge/DeltaMergeStore.h index 27d7d45ad14..c4f70061707 100644 --- a/dbms/src/Storages/DeltaMerge/DeltaMergeStore.h +++ b/dbms/src/Storages/DeltaMerge/DeltaMergeStore.h @@ -949,6 +949,8 @@ class DeltaMergeStore bool throw_if_notfound); void createFirstSegment(DM::DMContext & dm_context); + void updateColumnDefines(ColumnDefines && tmp_columns); + Context & global_context; std::shared_ptr path_pool; Settings settings; diff --git a/dbms/src/Storages/DeltaMerge/DeltaMergeStore_InternalSegment.cpp b/dbms/src/Storages/DeltaMerge/DeltaMergeStore_InternalSegment.cpp index ed641561fb9..b261ad3bf6c 100644 --- a/dbms/src/Storages/DeltaMerge/DeltaMergeStore_InternalSegment.cpp +++ b/dbms/src/Storages/DeltaMerge/DeltaMergeStore_InternalSegment.cpp @@ -506,12 +506,12 @@ void DeltaMergeStore::checkAllSegmentsLocalIndex(std::vector && dropped size_t segments_missing_indexes = 0; - // 2. Trigger EnsureStableLocalIndex for all segments. + // 2. Trigger EnsureStableLocalIndex & EnsureDeltaLocalIndex for all segments. // There could be new segments between 1 and 2, which is fine. New segments - // will invoke EnsureStableLocalIndex at creation time. + // will invoke EnsureStableLocalIndex & EnsureDeltaLocalIndex at creation time. { // There must be a lock, because segments[] may be mutated.
- // And one lock for all is fine, because segmentEnsureStableLocalIndexAsync is non-blocking, it + // And one lock for all is fine, because segmentEnsureStableLocalIndexAsync & segmentEnsureDeltaLocalIndexAsync are non-blocking, they // simply put tasks in the background. std::shared_lock lock(read_write_mutex); for (const auto & [end, segment] : segments) @@ -520,8 +520,9 @@ void DeltaMergeStore::checkAllSegmentsLocalIndex(std::vector && dropped // cleanup the index error message for dropped indexes segment->clearIndexBuildError(dropped_indexes); - if (segmentEnsureStableLocalIndexAsync(segment)) - ++segments_missing_indexes; + bool stable_missing_indexes = segmentEnsureStableLocalIndexAsync(segment); + bool delta_missing_indexes = segmentEnsureDeltaLocalIndexAsync(segment); + segments_missing_indexes += (stable_missing_indexes || delta_missing_indexes); } } @@ -573,11 +574,13 @@ bool DeltaMergeStore::segmentEnsureStableLocalIndexAsync(const SegmentPtr & segm { // new task of these index are generated, clear existing error_message in segment segment->clearIndexBuildError(build_info.indexesIDs()); - + auto file_ids = build_info.filesIDs(); + if (file_ids.empty()) + return true; auto [ok, reason] = indexer_scheduler->pushTask(LocalIndexerScheduler::Task{ .keyspace_id = keyspace_id, .table_id = physical_table_id, - .file_ids = build_info.filesIDs(), + .file_ids = file_ids, .request_memory = build_info.estimated_memory_bytes, .workload = workload, }); @@ -961,16 +964,24 @@ bool DeltaMergeStore::segmentWaitDeltaLocalIndexReady(const SegmentPtr & segment if (!column_file_persisted_set) return false; // ColumnFilePersistedSet is not exist, return false bool all_indexes_built = true; - for (const auto & index : *build_info.indexes_to_build) + auto delta_ptr = delta_weak_ptr.lock(); + if (auto lock = delta_ptr ?
delta_ptr->getLock() : std::nullopt; lock) { - for (const auto & column_file : column_file_persisted_set->getFiles()) + const auto & column_files = column_file_persisted_set->getFiles(); + for (const auto & index : *build_info.indexes_to_build) { - auto * tiny_file = column_file->tryToTinyFile(); - if (!tiny_file) - continue; - all_indexes_built = all_indexes_built && (tiny_file->hasIndex(index.index_id)); + for (const auto & column_file : column_files) + { + if (auto * tiny_file = column_file->tryToTinyFile(); tiny_file) + all_indexes_built = all_indexes_built && (tiny_file->hasIndex(index.index_id)); + } } } + else + { + // Delta has been abandoned, return false + return false; + } if (all_indexes_built) return true; std::this_thread::sleep_for(std::chrono::milliseconds(100)); // 0.1s diff --git a/dbms/src/Storages/DeltaMerge/File/ColumnCache.cpp b/dbms/src/Storages/DeltaMerge/File/ColumnCache.cpp index 684ea503dd7..d1560bc01fb 100644 --- a/dbms/src/Storages/DeltaMerge/File/ColumnCache.cpp +++ b/dbms/src/Storages/DeltaMerge/File/ColumnCache.cpp @@ -142,26 +142,86 @@ void ColumnCache::tryPutColumn( }); } -ColumnCacheElement ColumnCache::getColumn(size_t pack_id, ColId column_id) +ColumnPtr ColumnCache::getColumn(size_t start_pack_id, size_t end_pack_id, size_t read_rows, ColId column_id) +{ + return column_caches.withShared([&](auto & column_caches) -> ColumnPtr { + auto iter = column_caches.find(start_pack_id); + RUNTIME_CHECK_MSG(iter != column_caches.end(), "Cannot find column in cache, start_pack_id={}", start_pack_id); + const auto & columns = iter->second.columns; + auto col_iter = columns.find(column_id); + RUNTIME_CHECK_MSG( + col_iter != columns.end(), + "Cannot find column in cache, pack_id={} column_id={}", + start_pack_id, + column_id); + const auto & column = col_iter->second; + // Optimization for some special cases: + // 1. The requested column is exactly the same as the cached column, return directly without copying. 
+ if (iter->second.rows_offset == 0 && column->size() == read_rows) + return column; + // 2. The requested column is a subset of the cached column, cut the cached column and return. + assert(column->size() >= iter->second.rows_offset); + if (column->size() - iter->second.rows_offset >= read_rows) + return column->cut(iter->second.rows_offset, read_rows); + + // Otherwise, we need to copy data from multiple cached columns. + auto mut_col = column->cloneEmpty(); + getColumnImpl(column_caches, mut_col, start_pack_id, end_pack_id, read_rows, column_id); + return mut_col; + }); +} + +void ColumnCache::getColumn( + MutableColumnPtr & result, + size_t start_pack_id, + size_t end_pack_id, + size_t read_rows, + ColId column_id) { return column_caches.withShared([&](auto & column_caches) { - if (auto iter = column_caches.find(pack_id); iter != column_caches.end()) + getColumnImpl(column_caches, result, start_pack_id, end_pack_id, read_rows, column_id); + }); +} + +void ColumnCache::getColumnImpl( + const std::unordered_map & column_caches, + MutableColumnPtr & result, + size_t start_pack_id, + size_t end_pack_id, + size_t read_rows, + ColId column_id) +{ + size_t copied_rows = 0; + size_t processed_packs_rows = 0; + for (size_t cursor = start_pack_id; cursor < end_pack_id; ++cursor) + { + if (copied_rows >= read_rows) + break; + + auto iter = column_caches.find(cursor); + RUNTIME_CHECK_MSG(iter != column_caches.end(), "Cannot find column in cache, pack_id={}", cursor); + if (copied_rows > processed_packs_rows) { - auto & column_cache_entry = iter->second; - auto & columns = column_cache_entry.columns; - if (auto column_iter = columns.find(column_id); column_iter != columns.end()) - { - return std::make_pair( - column_iter->second, - std::make_pair(column_cache_entry.rows_offset, column_cache_entry.rows_count)); - } + // It could be that multiple cache_entries shared the same column ptr, and the rows have been copied + // in the previous loop.
+ processed_packs_rows += iter->second.rows_count; + continue; } - throw Exception( - ErrorCodes::LOGICAL_ERROR, - "Cannot find column in cache for pack id: {}, column id: {}", - pack_id, + const auto & columns = iter->second.columns; + auto col_iter = columns.find(column_id); + RUNTIME_CHECK_MSG( + col_iter != columns.end(), + "Cannot find column in cache, pack_id={} column_id={}", + start_pack_id, column_id); - }); + const auto & column = col_iter->second; + // Note that to_copied_rows could be larger than iter->second.rows_count, because the column ptr + // could be shared between multiple cache_entries. + size_t to_copied_rows = std::min(column->size() - iter->second.rows_offset, read_rows - copied_rows); + result->insertRangeFrom(*column, iter->second.rows_offset, to_copied_rows); + copied_rows += to_copied_rows; + processed_packs_rows += iter->second.rows_count; + } } void ColumnCache::delColumn(ColId column_id, size_t upper_pack_id) diff --git a/dbms/src/Storages/DeltaMerge/File/ColumnCache.h b/dbms/src/Storages/DeltaMerge/File/ColumnCache.h index 00f39246401..23afa321039 100644 --- a/dbms/src/Storages/DeltaMerge/File/ColumnCache.h +++ b/dbms/src/Storages/DeltaMerge/File/ColumnCache.h @@ -60,8 +60,16 @@ class ColumnCache : private boost::noncopyable void tryPutColumn(size_t pack_id, ColId column_id, const ColumnPtr & column, size_t rows_offset, size_t rows_count); - using ColumnCacheElement = std::pair>; - ColumnCacheElement getColumn(size_t pack_id, ColId column_id); + // Get column from cache, should make sure the column is in cache. + ColumnPtr getColumn(size_t start_pack_id, size_t end_pack_id, size_t read_rows, ColId column_id); + // Get column from cache, should make sure the column is in cache. + // Column data will be appended to `result`.
+ void getColumn( + MutableColumnPtr & result, + size_t start_pack_id, + size_t end_pack_id, + size_t read_rows, + ColId column_id); void delColumn(ColId column_id, size_t upper_pack_id); @@ -77,6 +85,14 @@ class ColumnCache : private boost::noncopyable ColId column_id, std::function is_hit); bool isPackInCache(PackId pack_id, ColId column_id); + struct ColumnCacheEntry; + static void getColumnImpl( + const std::unordered_map & column_caches, + MutableColumnPtr & result, + size_t start_pack_id, + size_t end_pack_id, + size_t read_rows, + ColId column_id); private: struct ColumnCacheEntry @@ -94,6 +110,5 @@ using ColumnCachePtr = std::shared_ptr; using ColumnCachePtrs = std::vector; using RangeWithStrategy = ColumnCache::RangeWithStrategy; using RangeWithStrategys = ColumnCache::RangeWithStrategys; -using ColumnCacheElement = ColumnCache::ColumnCacheElement; } // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/File/DMFile.h b/dbms/src/Storages/DeltaMerge/File/DMFile.h index 5cf8dce5d3f..fbb4d0526f6 100644 --- a/dbms/src/Storages/DeltaMerge/File/DMFile.h +++ b/dbms/src/Storages/DeltaMerge/File/DMFile.h @@ -38,7 +38,6 @@ int migrateServiceMain(DB::Context & context, const MigrateArgs & args); namespace DB::DM { -class DMFileWithVectorIndexBlockInputStream; namespace tests { class DMFileTest; @@ -313,6 +312,8 @@ class DMFile : private boost::noncopyable { case TiDB::ColumnarIndexKind::Vector: return fmt::format("idx_{}.vector", index_id); + case TiDB::ColumnarIndexKind::Inverted: + return fmt::format("idx_{}.inverted", index_id); default: throw Exception(fmt::format("Unsupported index kind: {}", magic_enum::enum_name(kind))); } @@ -344,7 +345,7 @@ class DMFile : private boost::noncopyable #endif DMFileMetaPtr meta; - friend class DMFileVectorIndexReader; + friend class VectorIndexReaderFromDMFile; friend class DMFileV3IncrementWriter; friend class DMFileWriter; friend class DMFileLocalIndexWriter; @@ -353,7 +354,7 @@ class DMFile : private boost::noncopyable 
friend class ColumnReadStream; friend class DMFilePackFilter; friend class DMFileBlockInputStreamBuilder; - friend class DMFileWithVectorIndexBlockInputStream; + friend class DMFileInputStreamProvideVectorIndex; friend class tests::DMFileTest; friend class tests::DMFileMetaV2Test; friend class tests::DMStoreForSegmentReadTaskTest; diff --git a/dbms/src/Storages/DeltaMerge/File/DMFileBlockInputStream.cpp b/dbms/src/Storages/DeltaMerge/File/DMFileBlockInputStream.cpp index a43c625ebc5..8c2ed24641b 100644 --- a/dbms/src/Storages/DeltaMerge/File/DMFileBlockInputStream.cpp +++ b/dbms/src/Storages/DeltaMerge/File/DMFileBlockInputStream.cpp @@ -14,8 +14,11 @@ #include #include -#include +#include #include +#include +#include +#include #include namespace DB::DM @@ -30,8 +33,6 @@ DMFileBlockInputStreamBuilder::DMFileBlockInputStreamBuilder(const Context & con setCaches( global_context.getMarkCache(), global_context.getMinMaxIndexCache(), - // DMFile vector index uses mmap to read data, does not directly occupy memory. - global_context.getLightLocalIndexCache(), global_context.getColumnCacheLongTerm()); // init from settings setFromSettings(context.getSettingsRef()); @@ -138,45 +139,38 @@ SkippableBlockInputStreamPtr DMFileBlockInputStreamBuilder::build( const RowKeyRanges & rowkey_ranges, const ScanContextPtr & scan_context) { - auto fallback = [&]() { - return buildNoLocalIndex(dmfile, read_columns, rowkey_ranges, scan_context); - }; - - if (!ann_query_info) - return fallback(); - - if (!bitmap_filter.has_value()) - return fallback(); - - Block header_layout = toEmptyBlock(read_columns); - - // Copy out the vector column for later use. Copy is intentionally performed after the - // fast check so that in fallback conditions we don't need unnecessary copies. 
- std::optional vec_column; - ColumnDefines rest_columns{}; - for (const auto & cd : read_columns) + if (vec_index_ctx) { - if (cd.id == ann_query_info->column_id()) - vec_column.emplace(cd); - else - rest_columns.emplace_back(cd); + // Note: this file may not have index built + return tryBuildWithVectorIndex(dmfile, read_columns, rowkey_ranges, scan_context); } - // No vector index column is specified, just use the normal logic. - if (!vec_column.has_value()) - return fallback(); + return buildNoLocalIndex(dmfile, read_columns, rowkey_ranges, scan_context); +} - RUNTIME_CHECK(rest_columns.size() + 1 == read_columns.size(), rest_columns.size(), read_columns.size()); +SkippableBlockInputStreamPtr DMFileBlockInputStreamBuilder::tryBuildWithVectorIndex( + const DMFilePtr & dmfile, + const ColumnDefines & read_columns, + const RowKeyRanges & rowkey_ranges, + const ScanContextPtr & scan_context) +{ + RUNTIME_CHECK(vec_index_ctx != nullptr); + RUNTIME_CHECK(read_columns.size() == vec_index_ctx->col_defs->size()); - const IndexID ann_query_info_index_id = ann_query_info->index_id() > 0 // - ? ann_query_info->index_id() - : EmptyIndexID; - if (!dmfile->isLocalIndexExist(vec_column->id, ann_query_info_index_id)) + auto fallback = [&]() { + vec_index_ctx->perf->n_from_dmf_noindex += 1; + return buildNoLocalIndex(dmfile, read_columns, rowkey_ranges, scan_context); + }; + + auto local_index = dmfile->getLocalIndex( // + vec_index_ctx->ann_query_info->column_id(), + vec_index_ctx->ann_query_info->index_id()); + if (!local_index.has_value()) // Vector index is defined but does not exist on the data file, // or there is no data at all return fallback(); - // All check passed. Let's read via vector index. 
+ RUNTIME_CHECK(local_index->index_props().kind() == dtpb::IndexFileKind::VECTOR_INDEX); bool enable_read_thread = SegmentReaderPoolManager::instance().isSegmentReader(); bool is_common_handle = !rowkey_ranges.empty() && rowkey_ranges[0].is_common_handle; @@ -199,7 +193,7 @@ SkippableBlockInputStreamPtr DMFileBlockInputStreamBuilder::build( DMFileReader rest_columns_reader( dmfile, - rest_columns, + *vec_index_ctx->rest_col_defs, is_common_handle, enable_handle_clean_read, enable_del_clean_read, @@ -223,18 +217,11 @@ SkippableBlockInputStreamPtr DMFileBlockInputStreamBuilder::build( // ColumnCacheLongTerm is only filled in Vector Search. rest_columns_reader.setColumnCacheLongTerm(column_cache_long_term, pk_col_id); - DMFileWithVectorIndexBlockInputStreamPtr reader = DMFileWithVectorIndexBlockInputStream::create( - ann_query_info, + vec_index_ctx->perf->n_from_dmf_index += 1; + return DMFileInputStreamProvideVectorIndex::create( // + vec_index_ctx, dmfile, - std::move(header_layout), - std::move(rest_columns_reader), - std::move(vec_column.value()), - scan_context, - local_index_cache, - bitmap_filter.value(), - tracing_id); - - return reader; + std::move(rest_columns_reader)); } } // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/File/DMFileBlockInputStream.h b/dbms/src/Storages/DeltaMerge/File/DMFileBlockInputStream.h index c0d084b4991..8d00d53f4ec 100644 --- a/dbms/src/Storages/DeltaMerge/File/DMFileBlockInputStream.h +++ b/dbms/src/Storages/DeltaMerge/File/DMFileBlockInputStream.h @@ -17,14 +17,11 @@ #include #include #include -#include #include #include #include #include -#include -#include -#include +#include #include #include #include @@ -118,15 +115,9 @@ class DMFileBlockInputStreamBuilder return *this; } - DMFileBlockInputStreamBuilder & setBitmapFilter(const BitmapFilterView & bitmap_filter_) + DMFileBlockInputStreamBuilder & setVecIndexQuery(const VectorIndexStreamCtxPtr & ctx) { - bitmap_filter.emplace(bitmap_filter_); - return *this; - } - - 
DMFileBlockInputStreamBuilder & setAnnQureyInfo(const ANNQueryInfoPtr & ann_query_info_) - { - ann_query_info = ann_query_info_; + vec_index_ctx = ctx; return *this; } @@ -196,6 +187,12 @@ class DMFileBlockInputStreamBuilder const RowKeyRanges & rowkey_ranges, const ScanContextPtr & scan_context); + SkippableBlockInputStreamPtr tryBuildWithVectorIndex( + const DMFilePtr & dmfile, + const ColumnDefines & read_columns, + const RowKeyRanges & rowkey_ranges, + const ScanContextPtr & scan_context); + private: // These methods are called by the ctor @@ -205,12 +202,10 @@ class DMFileBlockInputStreamBuilder DMFileBlockInputStreamBuilder & setCaches( const MarkCachePtr & mark_cache_, const MinMaxIndexCachePtr & index_cache_, - const LocalIndexCachePtr & local_index_cache_, const ColumnCacheLongTermPtr & column_cache_long_term_) { mark_cache = mark_cache_; index_cache = index_cache_; - local_index_cache = local_index_cache_; column_cache_long_term = column_cache_long_term_; return *this; } @@ -241,11 +236,9 @@ class DMFileBlockInputStreamBuilder DMFilePackFilterResultPtr pack_filter; - ANNQueryInfoPtr ann_query_info = nullptr; - - LocalIndexCachePtr local_index_cache; - // Note: Currently thie field is assigned only for Stable streams, not available for ColumnFileBig - std::optional bitmap_filter; + /// If set, will *try* to build a VectorIndexDMFileInputStream + /// instead of a normal DMFileBlockInputStream. + VectorIndexStreamCtxPtr vec_index_ctx = nullptr; // Note: column_cache_long_term is currently only filled when performing Vector Search. ColumnCacheLongTermPtr column_cache_long_term = nullptr; diff --git a/dbms/src/Storages/DeltaMerge/File/DMFileLocalIndexWriter.cpp b/dbms/src/Storages/DeltaMerge/File/DMFileLocalIndexWriter.cpp index 9d33d1d8f8f..6e0abfdfd78 100644 --- a/dbms/src/Storages/DeltaMerge/File/DMFileLocalIndexWriter.cpp +++ b/dbms/src/Storages/DeltaMerge/File/DMFileLocalIndexWriter.cpp @@ -13,6 +13,7 @@ // limitations under the License. 
#include +#include #include #include #include @@ -20,18 +21,18 @@ #include #include #include -#include +#include #include -#include #include -#include #include + namespace DB::ErrorCodes { extern const int ABORTED; } + namespace DB::FailPoints { extern const char exception_build_local_index_for_file[]; @@ -116,7 +117,7 @@ size_t DMFileLocalIndexWriter::buildIndexForFile(const DMFilePtr & dm_file_mutab LocalIndexInfo info; String index_file_path; // For write out String index_file_name; // For meta include - VectorIndexWriterOnDiskPtr builder_vector; + LocalIndexWriterOnDiskPtr index_writer; }; std::unordered_map> index_builders; @@ -127,7 +128,7 @@ size_t DMFileLocalIndexWriter::buildIndexForFile(const DMFilePtr & dm_file_mutab .info = index_info, .index_file_path = "", .index_file_name = "", - .builder_vector = {}, + .index_writer = {}, }); } @@ -156,17 +157,7 @@ size_t DMFileLocalIndexWriter::buildIndexForFile(const DMFilePtr & dm_file_mutab index.info.column_id, index.info.index_id); - switch (index.info.kind) - { - case TiDB::ColumnarIndexKind::Vector: - index.builder_vector = VectorIndexWriterOnDisk::create( // - index.index_file_path, - index.info.def_vector_index); - break; - default: - RUNTIME_CHECK_MSG(false, "Unsupported index kind: {}", magic_enum::enum_name(index.info.kind)); - break; - } + index.index_writer = LocalIndexWriter::createOnDisk(index.index_file_path, index.info); } read_columns.push_back(*cd_iter); } @@ -213,16 +204,8 @@ size_t DMFileLocalIndexWriter::buildIndexForFile(const DMFilePtr & dm_file_mutab const auto & col = col_with_type_and_name.column; for (const auto & index : index_builders[read_columns[col_idx].id]) { - switch (index.info.kind) - { - case TiDB::ColumnarIndexKind::Vector: - RUNTIME_CHECK(index.builder_vector); - index.builder_vector->addBlock(*col, del_mark, should_proceed); - break; - default: - RUNTIME_CHECK_MSG(false, "Unsupported index kind: {}", magic_enum::enum_name(index.info.kind)); - break; - } + 
RUNTIME_CHECK(index.index_writer); + index.index_writer->addBlock(*col, del_mark, should_proceed); } } } @@ -240,33 +223,10 @@ size_t DMFileLocalIndexWriter::buildIndexForFile(const DMFilePtr & dm_file_mutab for (const auto & index : index_builders[cd.id]) { - dtpb::DMFileIndexInfo pb_dmfile_idx{}; - auto * pb_idx = pb_dmfile_idx.mutable_index_props(); - - switch (index.info.kind) - { - case TiDB::ColumnarIndexKind::Vector: - { - index.builder_vector->finalize(); - auto * pb_vec_idx = pb_idx->mutable_vector_index(); - pb_vec_idx->set_format_version(0); - pb_vec_idx->set_dimensions(index.info.def_vector_index->dimension); - pb_vec_idx->set_distance_metric( - tipb::VectorDistanceMetric_Name(index.info.def_vector_index->distance_metric)); - break; - } - default: - RUNTIME_CHECK_MSG(false, "Unsupported index kind: {}", magic_enum::enum_name(index.info.kind)); - break; - } - - auto index_file = Poco::File(index.index_file_path); - RUNTIME_CHECK(index_file.exists()); - pb_idx->set_kind(index.info.getKindAsDtpb()); - pb_idx->set_index_id(index.info.index_id); - pb_idx->set_file_size(index_file.getSize()); - - total_built_index_bytes += pb_idx->file_size(); + dtpb::DMFileIndexInfo pb_dmfile_idx; + auto idx_info = index.index_writer->finalize(); + pb_dmfile_idx.mutable_index_props()->Swap(&idx_info); + total_built_index_bytes += pb_dmfile_idx.index_props().file_size(); new_indexes.emplace_back(std::move(pb_dmfile_idx)); iw->include(index.index_file_name); } diff --git a/dbms/src/Storages/DeltaMerge/File/DMFileLocalIndexWriter.h b/dbms/src/Storages/DeltaMerge/File/DMFileLocalIndexWriter.h index 5768f74f409..930236a03b0 100644 --- a/dbms/src/Storages/DeltaMerge/File/DMFileLocalIndexWriter.h +++ b/dbms/src/Storages/DeltaMerge/File/DMFileLocalIndexWriter.h @@ -45,6 +45,8 @@ struct LocalIndexBuildInfo ids.reserve(dm_files.size()); for (const auto & dmf : dm_files) { + if (unlikely(dmf->getRows() == 0)) + continue; ids.emplace_back(LocalIndexerScheduler::DMFileID(dmf->fileId())); 
} return ids; diff --git a/dbms/src/Storages/DeltaMerge/File/DMFileReader.cpp b/dbms/src/Storages/DeltaMerge/File/DMFileReader.cpp index 7d162d0f598..573fd348a06 100644 --- a/dbms/src/Storages/DeltaMerge/File/DMFileReader.cpp +++ b/dbms/src/Storages/DeltaMerge/File/DMFileReader.cpp @@ -635,29 +635,36 @@ ColumnPtr DMFileReader::getColumnFromCache( const auto col_id = cd.id; auto read_strategy = data_cache->getReadStrategy(start_pack_id, pack_count, col_id); + if (read_strategy.size() == 1) + { + auto [range, strategy] = read_strategy.front(); + if (strategy == ColumnCache::Strategy::Memory) + { + return data_cache->getColumn(range.first, range.second, read_rows, col_id); + } + else if (strategy == ColumnCache::Strategy::Disk) + { + return on_cache_miss(cd, type_on_disk, range.first, range.second - range.first, read_rows); + } + + throw Exception("Unknown strategy", ErrorCodes::LOGICAL_ERROR); + } + const auto & pack_stats = dmfile->getPackStats(); auto mutable_col = type_on_disk->createColumn(); - mutable_col->reserve(read_rows); for (auto & [range, strategy] : read_strategy) { + size_t rows_count = 0; + for (size_t cursor = range.first; cursor < range.second; ++cursor) + { + rows_count += pack_stats[cursor].rows; + } if (strategy == ColumnCache::Strategy::Memory) { - for (size_t cursor = range.first; cursor < range.second; ++cursor) - { - auto cache_element = data_cache->getColumn(cursor, col_id); - mutable_col->insertRangeFrom( - *(cache_element.first), - cache_element.second.first, - cache_element.second.second); - } + data_cache->getColumn(mutable_col, range.first, range.second, rows_count, col_id); } else if (strategy == ColumnCache::Strategy::Disk) { - size_t rows_count = 0; - for (size_t cursor = range.first; cursor < range.second; cursor++) - { - rows_count += pack_stats[cursor].rows; - } auto sub_col = on_cache_miss(cd, type_on_disk, range.first, range.second - range.first, rows_count); mutable_col->insertRangeFrom(*sub_col, 0, sub_col->size()); } diff 
--git a/dbms/src/Storages/DeltaMerge/File/DMFileReader.h b/dbms/src/Storages/DeltaMerge/File/DMFileReader.h index bb477865c56..9967a5d107b 100644 --- a/dbms/src/Storages/DeltaMerge/File/DMFileReader.h +++ b/dbms/src/Storages/DeltaMerge/File/DMFileReader.h @@ -31,12 +31,10 @@ namespace DB::DM { -class DMFileWithVectorIndexBlockInputStream; - class DMFileReader { - friend class DMFileWithVectorIndexBlockInputStream; + friend class DMFileInputStreamProvideVectorIndex; friend class DMFileReaderPoolSharding; public: diff --git a/dbms/src/Storages/DeltaMerge/File/DMFileVectorIndexReader.cpp b/dbms/src/Storages/DeltaMerge/File/DMFileVectorIndexReader.cpp deleted file mode 100644 index c7dc10ad75b..00000000000 --- a/dbms/src/Storages/DeltaMerge/File/DMFileVectorIndexReader.cpp +++ /dev/null @@ -1,226 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace DB::ErrorCodes -{ -extern const int S3_ERROR; -} // namespace DB::ErrorCodes - -namespace DB::DM -{ - -String DMFileVectorIndexReader::PerfStat::toString() const -{ - return fmt::format( - "index_size={} load={:.2f}s{}{}, search={:.2f}s, read_vec_column={:.2f}s", - index_size, - duration_load_index, - has_s3_download ? " (S3)" : "", - has_load_from_file ? 
" (LoadFile)" : "", - duration_search, - duration_read_vec_column); -} - -std::vector DMFileVectorIndexReader::load() -{ - if (loaded) - return {}; - - loadVectorIndex(); - auto sorted_results = loadVectorSearchResult(); - - perf_stat.selected_nodes = sorted_results.size(); - loaded = true; - return sorted_results; -} - -void DMFileVectorIndexReader::loadVectorIndex() -{ - const auto col_id = ann_query_info->column_id(); - const auto index_id = ann_query_info->index_id() > 0 ? ann_query_info->index_id() : EmptyIndexID; - - RUNTIME_CHECK(dmfile->useMetaV2()); // v3 - - // Check vector index exists on the column - auto vector_index = dmfile->getLocalIndex(col_id, index_id); - RUNTIME_CHECK(vector_index.has_value(), col_id, index_id); - RUNTIME_CHECK(vector_index->index_props().kind() == dtpb::IndexFileKind::VECTOR_INDEX); - RUNTIME_CHECK(vector_index->index_props().has_vector_index()); - perf_stat.index_size = vector_index->index_props().file_size(); - - // If local file is invalidated, cache is not valid anymore. So we - // need to ensure file exists on local fs first. - const auto index_file_path = index_id > 0 // - ? dmfile->localIndexPath(index_id, TiDB::ColumnarIndexKind::Vector) // - : dmfile->colIndexPath(DMFile::getFileNameBase(col_id)); - String local_index_file_path; - if (auto s3_file_name = S3::S3FilenameView::fromKeyWithPrefix(index_file_path); s3_file_name.isValid()) - { - // Disaggregated mode - auto * file_cache = FileCache::instance(); - RUNTIME_CHECK_MSG(file_cache, "Must enable S3 file cache to use vector index"); - - Stopwatch watch; - - auto perf_begin = PerfContext::file_cache; - - // If download file failed, retry a few times. 
- for (auto i = 3; i > 0; --i) - { - try - { - if (auto file_guard = file_cache->downloadFileForLocalRead( // - s3_file_name, - vector_index->index_props().file_size()); - file_guard) - { - local_index_file_path = file_guard->getLocalFileName(); - break; // Successfully downloaded index into local cache - } - - throw Exception(ErrorCodes::S3_ERROR, "Failed to download vector index file {}", index_file_path); - } - catch (...) - { - if (i <= 1) - throw; - } - } - - if ( // - PerfContext::file_cache.fg_download_from_s3 > perf_begin.fg_download_from_s3 || // - PerfContext::file_cache.fg_wait_download_from_s3 > perf_begin.fg_wait_download_from_s3) - perf_stat.has_s3_download = true; - - auto download_duration = watch.elapsedSeconds(); - perf_stat.duration_load_index += download_duration; - - GET_METRIC(tiflash_vector_index_duration, type_download).Observe(download_duration); - } - else - { - // Not disaggregated mode - local_index_file_path = index_file_path; - } - - auto load_from_file = [&]() { - perf_stat.has_load_from_file = true; - return VectorIndexReader::createFromMmap(vector_index->index_props().vector_index(), local_index_file_path); - }; - - Stopwatch watch; - if (local_index_cache) - { - // Note: must use local_index_file_path as the cache key, because cache - // will check whether file is still valid and try to remove memory references - // when file is dropped. - auto local_index = local_index_cache->getOrSet(local_index_file_path, load_from_file); - vec_index = std::dynamic_pointer_cast(local_index); - } - else - { - vec_index = load_from_file(); - } - - perf_stat.duration_load_index += watch.elapsedSeconds(); - RUNTIME_CHECK(vec_index != nullptr); - - scan_context->total_vector_idx_load_time_ms += static_cast(perf_stat.duration_load_index * 1000); - if (perf_stat.has_s3_download) - // it could be possible that s3=true but load_from_file=false, it means we download a file - // and then reuse the memory cache. 
The majority time comes from s3 download - // so we still count it as s3 download. - scan_context->total_vector_idx_load_from_s3++; - else if (perf_stat.has_load_from_file) - scan_context->total_vector_idx_load_from_disk++; - else - scan_context->total_vector_idx_load_from_cache++; -} - -DMFileVectorIndexReader::~DMFileVectorIndexReader() -{ - scan_context->total_vector_idx_read_vec_time_ms += static_cast(perf_stat.duration_read_vec_column * 1000); -} - -String DMFileVectorIndexReader::perfStat() const -{ - return fmt::format( - "{} top_k_[query/visited/discarded/result]={}/{}/{}/{}", - perf_stat.toString(), - ann_query_info->top_k(), - perf_stat.visited_nodes, - perf_stat.discarded_nodes, - perf_stat.selected_nodes); -} - -std::vector DMFileVectorIndexReader::loadVectorSearchResult() -{ - Stopwatch watch; - - auto perf_begin = PerfContext::vector_search; - - RUNTIME_CHECK(valid_rows.size() >= dmfile->getRows(), valid_rows.size(), dmfile->getRows()); - auto search_results = vec_index->search(ann_query_info, valid_rows); - // Sort by key - std::sort(search_results.begin(), search_results.end(), [](const auto & lhs, const auto & rhs) { - return lhs.key < rhs.key; - }); - // results must not contain duplicates. Usually there should be no duplicates. 
- search_results.erase( - std::unique( - search_results.begin(), - search_results.end(), - [](const auto & lhs, const auto & rhs) { return lhs.key == rhs.key; }), - search_results.end()); - - perf_stat.discarded_nodes = PerfContext::vector_search.discarded_nodes - perf_begin.discarded_nodes; - perf_stat.visited_nodes = PerfContext::vector_search.visited_nodes - perf_begin.visited_nodes; - - perf_stat.duration_search = watch.elapsedSeconds(); - scan_context->total_vector_idx_search_time_ms += static_cast(perf_stat.duration_search * 1000); - scan_context->total_vector_idx_search_discarded_nodes += perf_stat.discarded_nodes; - scan_context->total_vector_idx_search_visited_nodes += perf_stat.visited_nodes; - - return search_results; -} - -void DMFileVectorIndexReader::read( - MutableColumnPtr & vec_column, - const std::span & selected_rows) -{ - Stopwatch watch; - RUNTIME_CHECK(loaded); - - vec_column->reserve(selected_rows.size()); - std::vector value; - for (auto rowid : selected_rows) - { - vec_index->get(rowid, value); - vec_column->insertData(reinterpret_cast(value.data()), value.size() * sizeof(Float32)); - } - - perf_stat.duration_read_vec_column += watch.elapsedSeconds(); -} - -} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/File/DMFileVectorIndexReader.h b/dbms/src/Storages/DeltaMerge/File/DMFileVectorIndexReader.h deleted file mode 100644 index 34fb4ef8691..00000000000 --- a/dbms/src/Storages/DeltaMerge/File/DMFileVectorIndexReader.h +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include -#include -#include - -namespace DB::DM -{ - -class DMFileVectorIndexReader -{ -private: - const DMFilePtr & dmfile; - const ANNQueryInfoPtr & ann_query_info; - const BitmapFilterView valid_rows; - const ScanContextPtr & scan_context; - // Global local index cache - const LocalIndexCachePtr local_index_cache; - - // Performance statistics - struct PerfStat - { - double duration_search; - double duration_load_index; - double duration_read_vec_column; - size_t index_size; - size_t visited_nodes; - size_t discarded_nodes; - size_t selected_nodes; - bool has_s3_download; - bool has_load_from_file; - - String toString() const; - }; - PerfStat perf_stat; - - // Set after load(). - VectorIndexReaderPtr vec_index = nullptr; - bool loaded = false; - -public: - DMFileVectorIndexReader( - const ANNQueryInfoPtr & ann_query_info_, - const DMFilePtr & dmfile_, - const BitmapFilterView & valid_rows_, - const ScanContextPtr & scan_context_, - const LocalIndexCachePtr & local_index_cache_) - : dmfile(dmfile_) - , ann_query_info(ann_query_info_) - , valid_rows(valid_rows_) - , scan_context(scan_context_) - , local_index_cache(local_index_cache_) - , perf_stat() - {} - - ~DMFileVectorIndexReader(); - - // Read vector column data with the specified rowids. - void read(MutableColumnPtr & vec_column, const std::span & selected_rows); - - // Load vector index and search results. - // Return the rowids of the selected rows. 
- std::vector load(); - - String perfStat() const; - -private: - void loadVectorIndex(); - std::vector loadVectorSearchResult(); -}; - -using DMFileVectorIndexReaderPtr = std::shared_ptr; - -} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/File/DMFileWithVectorIndexBlockInputStream.cpp b/dbms/src/Storages/DeltaMerge/File/DMFileWithVectorIndexBlockInputStream.cpp deleted file mode 100644 index e8ed1c86bd2..00000000000 --- a/dbms/src/Storages/DeltaMerge/File/DMFileWithVectorIndexBlockInputStream.cpp +++ /dev/null @@ -1,221 +0,0 @@ -// Copyright 2024 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include -#include -#include - -namespace DB::DM -{ - -DMFileWithVectorIndexBlockInputStream::DMFileWithVectorIndexBlockInputStream( - const ANNQueryInfoPtr & ann_query_info_, - const DMFilePtr & dmfile_, - Block && header_, - DMFileReader && reader_, - ColumnDefine && vec_cd_, - const ScanContextPtr & scan_context_, - const LocalIndexCachePtr & local_index_cache_, - const BitmapFilterView & valid_rows_, - const String & tracing_id) - : log(Logger::get(tracing_id)) - , ann_query_info(ann_query_info_) - , dmfile(dmfile_) - , header(std::move(header_)) - , reader(std::move(reader_)) - , vec_cd(std::move(vec_cd_)) - , scan_context(scan_context_) - , vec_index_reader(std::make_shared( - ann_query_info, - dmfile, - valid_rows_, - scan_context, - local_index_cache_)) -{} - -DMFileWithVectorIndexBlockInputStream::~DMFileWithVectorIndexBlockInputStream() -{ - scan_context->total_vector_idx_read_others_time_ms - += static_cast(duration_read_from_other_columns_seconds * 1000); - - LOG_DEBUG( - log, - "Finished vector search over column dmf_{}/{}(id={}), index_id={} {} " - "pack_[total/before_search/after_search]={}/{}/{}", - dmfile->fileId(), - vec_cd.name, - vec_cd.id, - ann_query_info->index_id(), - - vec_index_reader->perfStat(), - - dmfile->getPackStats().size(), - valid_packs_before_search, - valid_packs_after_search); -} - -Block DMFileWithVectorIndexBlockInputStream::read() -{ - internalLoad(); - - if (reader.read_block_infos.empty()) - return {}; - - const auto [start_pack_id, pack_count, rs_result, read_rows] = reader.read_block_infos.front(); - const auto start_row_offset = reader.pack_offset[start_pack_id]; - - auto vec_column = vec_cd.type->createColumn(); - auto begin = std::lower_bound(sorted_results.cbegin(), sorted_results.cend(), start_row_offset); - auto end = std::lower_bound(begin, sorted_results.cend(), start_row_offset + read_rows); - const std::span block_selected_rows{begin, end}; - if (block_selected_rows.empty()) - return {}; - - // read vector 
column - vec_index_reader->read(vec_column, block_selected_rows); - - Block block; - - // read other columns if needed - if (!reader.read_columns.empty()) - { - Stopwatch w; - - filter.clear(); - filter.resize_fill(read_rows, 0); - for (const auto rowid : block_selected_rows) - filter[rowid - start_row_offset] = 1; - - block = reader.read(); - for (auto & col : block) - col.column = col.column->filter(filter, block_selected_rows.size()); - duration_read_from_other_columns_seconds += w.elapsedSeconds(); - } - else - { - // Since we do not call `reader.read()` here, we need to pop the read_block_infos manually. - reader.read_block_infos.pop_front(); - } - - auto index = header.getPositionByName(vec_cd.name); - block.insert(index, ColumnWithTypeAndName{std::move(vec_column), vec_cd.type, vec_cd.name, vec_cd.id}); - - block.setStartOffset(start_row_offset); - block.setRSResult(rs_result); - return block; -} - -std::vector DMFileWithVectorIndexBlockInputStream::load() -{ - if (loaded) - return {}; - - auto search_results = vec_index_reader->load(); - return search_results; -} - -void DMFileWithVectorIndexBlockInputStream::internalLoad() -{ - if (loaded) - return; - - auto search_results = vec_index_reader->load(); - sorted_results.reserve(search_results.size()); - for (const auto & row : search_results) - sorted_results.push_back(row.key); - - updateReadBlockInfos(); -} - -void DMFileWithVectorIndexBlockInputStream::updateReadBlockInfos() -{ - // Vector index is very likely to filter out some packs. For example, - // if we query for Top 1, then only 1 pack will be remained. So we - // update the reader's read_block_infos to avoid reading unnecessary data for other columns. - - // The following logic is nearly the same with DMFileReader::initReadBlockInfos. 
- - auto & read_block_infos = reader.read_block_infos; - const auto & pack_offset = reader.pack_offset; - - read_block_infos.clear(); - const auto & pack_stats = dmfile->getPackStats(); - const auto & pack_res = reader.pack_filter->getPackRes(); - - // Update valid_packs_before_search - for (const auto res : pack_res) - valid_packs_before_search += res.isUse(); - - // Update read_block_infos - size_t start_pack_id = 0; - size_t read_rows = 0; - auto prev_block_pack_res = RSResult::All; - auto sorted_results_it = sorted_results.cbegin(); - size_t pack_id = 0; - for (; pack_id < pack_stats.size(); ++pack_id) - { - if (sorted_results_it == sorted_results.cend()) - break; - auto begin = std::lower_bound(sorted_results_it, sorted_results.cend(), pack_offset[pack_id]); - auto end = std::lower_bound(begin, sorted_results.cend(), pack_offset[pack_id] + pack_stats[pack_id].rows); - bool is_use = begin != end; - bool reach_limit = read_rows >= reader.rows_threshold_per_read; - bool break_all_match = prev_block_pack_res.allMatch() && !pack_res[pack_id].allMatch() - && read_rows >= reader.rows_threshold_per_read / 2; - - if (!is_use) - { - if (read_rows > 0) - read_block_infos.emplace_back(start_pack_id, pack_id - start_pack_id, prev_block_pack_res, read_rows); - start_pack_id = pack_id + 1; - read_rows = 0; - prev_block_pack_res = RSResult::All; - } - else if (reach_limit || break_all_match) - { - if (read_rows > 0) - read_block_infos.emplace_back(start_pack_id, pack_id - start_pack_id, prev_block_pack_res, read_rows); - start_pack_id = pack_id; - read_rows = pack_stats[pack_id].rows; - prev_block_pack_res = pack_res[pack_id]; - } - else - { - prev_block_pack_res = prev_block_pack_res && pack_res[pack_id]; - read_rows += pack_stats[pack_id].rows; - } - - sorted_results_it = end; - } - if (read_rows > 0) - read_block_infos.emplace_back(start_pack_id, pack_id - start_pack_id, prev_block_pack_res, read_rows); - - // Update valid_packs_after_search - for (const auto & block_info 
: read_block_infos) - valid_packs_after_search += block_info.pack_count; - - RUNTIME_CHECK_MSG(sorted_results_it == sorted_results.cend(), "All results are not consumed"); - loaded = true; -} - -void DMFileWithVectorIndexBlockInputStream::setSelectedRows(const std::span & selected_rows) -{ - sorted_results.clear(); - sorted_results.reserve(selected_rows.size()); - std::copy(selected_rows.begin(), selected_rows.end(), std::back_inserter(sorted_results)); - - updateReadBlockInfos(); -} - -} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/File/DMFileWithVectorIndexBlockInputStream.h b/dbms/src/Storages/DeltaMerge/File/DMFileWithVectorIndexBlockInputStream.h deleted file mode 100644 index 8d0a2ca6a26..00000000000 --- a/dbms/src/Storages/DeltaMerge/File/DMFileWithVectorIndexBlockInputStream.h +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright 2024 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include - - -namespace DB::DM -{ - -/** - * @brief DMFileWithVectorIndexBlockInputStream is similar to DMFileBlockInputStream. - * However it can read data efficiently with the help of vector index. - * - * General steps: - * 1. Read all PK, Version and Del Marks (respecting Pack filters). - * 2. Construct a bitmap of valid rows (in memory). This bitmap guides the reading of vector index to determine whether a row is valid or not. 
- * - * Note: Step 1 and 2 simply rely on the BitmapFilter to avoid repeat IOs. - * BitmapFilter is global, which provides row valid info for all DMFile + Delta. - * What we need is which rows are valid in THIS DMFile. - * To transform a global BitmapFilter result into a local one, RowOffsetTracker is used. - * - * 3. Perform a vector search for Top K vector rows. We now have K row_ids whose vector distance is close. - * 4. Map these row_ids to packids as the new pack filter. - * 5. Read from other columns with the new pack filter. - * For each read, join other columns and vector column together. - * - * Step 3~4 is performed lazily at first read. - * - * Before constructing this class, the caller must ensure that vector index - * exists on the corresponding column. If the index does not exist, the caller - * should use the standard DMFileBlockInputStream. - */ -class DMFileWithVectorIndexBlockInputStream : public VectorIndexBlockInputStream -{ -public: - static DMFileWithVectorIndexBlockInputStreamPtr create( - const ANNQueryInfoPtr & ann_query_info, - const DMFilePtr & dmfile, - Block && header, - DMFileReader && reader, - ColumnDefine && vec_cd, - const ScanContextPtr & scan_context, - const LocalIndexCachePtr & local_index_cache, - const BitmapFilterView & valid_rows, - const String & tracing_id) - { - return std::make_shared( - ann_query_info, - dmfile, - std::move(header), - std::move(reader), - std::move(vec_cd), - scan_context, - local_index_cache, - valid_rows, - tracing_id); - } - - explicit DMFileWithVectorIndexBlockInputStream( - const ANNQueryInfoPtr & ann_query_info_, - const DMFilePtr & dmfile_, - Block && header_, - DMFileReader && reader_, - ColumnDefine && vec_cd_, - const ScanContextPtr & scan_context_, - const LocalIndexCachePtr & local_index_cache_, - const BitmapFilterView & valid_rows_, - const String & tracing_id); - - ~DMFileWithVectorIndexBlockInputStream() override; - -public: - Block read() override; - - String getName() const override { 
return "DMFileWithVectorIndex"; } - - Block getHeader() const override { return header; } - - std::vector load() override; - - void setSelectedRows(const std::span & selected_rows) override; - -private: - // Load vector index and update sorted_results. - void internalLoad(); - - // Update the read_block_infos according to the sorted_results. - void updateReadBlockInfos(); - -private: - const LoggerPtr log; - - const ANNQueryInfoPtr ann_query_info; - const DMFilePtr dmfile; - - // The header contains columns from reader and vec_cd - Block header; - // Vector column should be excluded in the reader - DMFileReader reader; - // Note: ColumnDefine comes from read path does not have vector_index fields. - const ColumnDefine vec_cd; - const ScanContextPtr scan_context; - const DMFileVectorIndexReaderPtr vec_index_reader; - - // Set after load(). - VectorIndexReaderPtr vec_index = nullptr; - // VectorColumnFromIndexReaderPtr vec_column_reader = nullptr; - // Set after load(). Used to filter the output rows. - std::vector sorted_results{}; // Key is rowid - IColumn::Filter filter; - - bool loaded = false; - - double duration_read_from_other_columns_seconds = 0; - size_t valid_packs_before_search = 0; - size_t valid_packs_after_search = 0; -}; - -} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/InvertedIndex/CommonUtil.cpp b/dbms/src/Storages/DeltaMerge/Index/InvertedIndex/CommonUtil.cpp new file mode 100644 index 00000000000..3f9d2ebeba9 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/InvertedIndex/CommonUtil.cpp @@ -0,0 +1,189 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + + +namespace DB::DM::InvertedIndex +{ + +template +void Block::serialize(WriteBuffer & write_buf) const +{ + writeIntBinary(static_cast(entries.size()), write_buf); + // write all values first + for (const auto & entry : entries) + { + writeIntBinary(entry.value, write_buf); + writeIntBinary(static_cast(entry.row_ids.size()), write_buf); + } + // write all row_ids + for (const auto & entry : entries) + { + const auto & row_ids = entry.row_ids; + write_buf.write(reinterpret_cast(row_ids.data()), row_ids.size() * sizeof(RowID)); + } +} + +template +void Block::deserialize(Block & block, ReadBuffer & read_buf) +{ + UInt32 size; + readIntBinary(size, read_buf); + block.entries.resize(size); + for (UInt32 i = 0; i < size; ++i) + { + T value; + readIntBinary(value, read_buf); + UInt32 row_ids_size; + readIntBinary(row_ids_size, read_buf); + block.entries[i].value = value; + block.entries[i].row_ids.resize(row_ids_size); + } + for (UInt32 i = 0; i < size; ++i) + { + auto & entry = block.entries[i]; + read_buf.readStrict(reinterpret_cast(entry.row_ids.data()), entry.row_ids.size() * sizeof(RowID)); + } +} + +template +void Block::search(BitmapFilterPtr & bitmap_filter, ReadBuffer & read_buf, T key) +{ + UInt32 size; + readIntBinary(size, read_buf); + UInt32 seek_offset = size * (sizeof(T) + sizeof(UInt32)); + for (UInt32 i = 0; i < size; ++i) + { + T value; + readIntBinary(value, read_buf); + UInt32 row_ids_size; + readIntBinary(row_ids_size, read_buf); + seek_offset -= (sizeof(T) + sizeof(UInt32)); + if (value == key) + { 
+ // ignore the rest values and previous row_ids + read_buf.ignore(seek_offset); + RowIDs row_ids(row_ids_size); + read_buf.readStrict(reinterpret_cast(row_ids.data()), row_ids_size * sizeof(RowID)); + bitmap_filter->set(row_ids, nullptr); + return; + } + seek_offset += row_ids_size * sizeof(RowID); + } +} + +template +void Block::searchRange(BitmapFilterPtr & bitmap_filter, ReadBuffer & read_buf, T begin, T end) +{ + UInt32 read_count = read_buf.count(); + UInt32 size; + readIntBinary(size, read_buf); + UInt32 acc_row_ids_size = 0; + UInt32 start_offset = 0; + UInt32 end_offset = 0; + for (UInt32 i = 0; i < size; ++i) + { + T value; + readIntBinary(value, read_buf); + UInt32 row_ids_size; + readIntBinary(row_ids_size, read_buf); + if (value >= begin && value <= end && start_offset == 0) + start_offset = sizeof(UInt32) + size * (sizeof(T) + sizeof(UInt32)) + acc_row_ids_size * sizeof(RowID); + acc_row_ids_size += row_ids_size; + if (value >= begin && value <= end) + end_offset = sizeof(UInt32) + size * (sizeof(T) + sizeof(UInt32)) + acc_row_ids_size * sizeof(RowID); + if (value > end) + break; + } + + if (start_offset == 0) + return; + + read_count = read_buf.count() - read_count; + read_buf.ignore(start_offset - read_count); + RowIDs row_ids((end_offset - start_offset) / sizeof(RowID)); + read_buf.readStrict(reinterpret_cast(row_ids.data()), row_ids.size() * sizeof(RowID)); + bitmap_filter->set(row_ids, nullptr); +} + +template +void MetaEntry::serialize(WriteBuffer & write_buf) const +{ + writeIntBinary(offset, write_buf); + writeIntBinary(size, write_buf); + writeIntBinary(min, write_buf); + writeIntBinary(max, write_buf); +} + +template +void MetaEntry::deserialize(MetaEntry & entry, ReadBuffer & read_buf) +{ + readIntBinary(entry.offset, read_buf); + readIntBinary(entry.size, read_buf); + readIntBinary(entry.min, read_buf); + readIntBinary(entry.max, read_buf); +} + +template +void Meta::serialize(WriteBuffer & write_buf) const +{ + 
writeIntBinary(static_cast(sizeof(T)), write_buf); + writeIntBinary(static_cast(entries.size()), write_buf); + for (const auto & entry : entries) + entry.serialize(write_buf); +} + +template +void Meta::deserialize(Meta & meta, ReadBuffer & read_buf) +{ + UInt8 type_size; + readIntBinary(type_size, read_buf); + RUNTIME_CHECK(type_size == sizeof(T)); + + UInt32 size; + readIntBinary(size, read_buf); + meta.entries.resize(size); + for (auto & entry : meta.entries) + MetaEntry::deserialize(entry, read_buf); +} + +template struct Block; +template struct Block; +template struct Block; +template struct Block; +template struct Block; +template struct Block; +template struct Block; +template struct Block; +template struct MetaEntry; +template struct MetaEntry; +template struct MetaEntry; +template struct MetaEntry; +template struct MetaEntry; +template struct MetaEntry; +template struct MetaEntry; +template struct MetaEntry; +template struct Meta; +template struct Meta; +template struct Meta; +template struct Meta; +template struct Meta; +template struct Meta; +template struct Meta; +template struct Meta; + +} // namespace DB::DM::InvertedIndex diff --git a/dbms/src/Storages/DeltaMerge/Index/InvertedIndex/CommonUtil.h b/dbms/src/Storages/DeltaMerge/Index/InvertedIndex/CommonUtil.h new file mode 100644 index 00000000000..ed34ed741af --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/InvertedIndex/CommonUtil.h @@ -0,0 +1,99 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +namespace DB::DM::InvertedIndex +{ + +enum class Version +{ + Invalid = 0, + V1 = 1, +}; + +// InvertedIndex file format (V1): +// | VERSION | Block 0 (compressed) | Block 1 (compressed) | ... | Block N (compressed) | Meta | Meta size | Magic flag | + +// Block format: +// | number of values | value | row_ids size | value | row_ids size | ... | value | row_ids size | row_ids | row_ids | ... | row_ids | + +// Meta format: +// | size of T | number of blocks | offset | size | min | max | offset | size | min | max | ... | offset | size | min | max | + +using RowID = UInt32; +using RowIDs = std::vector; + +// A block is a minimal unit of IO, it will be as small as possible, but >= 64KB. +static constexpr size_t BlockSize = 64 * 1024; // 64 KB + +// +template +struct BlockEntry +{ + T value; + RowIDs row_ids; +}; + +template +struct Block +{ + Block() = default; + + std::vector> entries; + void serialize(WriteBuffer & write_buf) const; + static void deserialize(Block & block, ReadBuffer & read_buf); + + static void search(BitmapFilterPtr & bitmap_filter, ReadBuffer & read_buf, T key); + static void searchRange(BitmapFilterPtr & bitmap_filter, ReadBuffer & read_buf, T begin, T end); +}; + +template +struct MetaEntry +{ + UInt32 offset; // offset in the file + UInt32 size; // block size + T min; + T max; + + void serialize(WriteBuffer & write_buf) const; + static void deserialize(MetaEntry & entry, ReadBuffer & read_buf); +}; + +template +struct Meta +{ + Meta() = default; + + std::vector> entries; + void serialize(WriteBuffer & write_buf) const; + static void deserialize(Meta & meta, ReadBuffer & read_buf); +}; + +static std::string_view constexpr MagicFlag = "INVE"; +static UInt32 constexpr MagicFlagLength = MagicFlag.size(); + +// Get the size of the block in bytes. 
+template +constexpr size_t getBlockSize(UInt32 entry_size, UInt32 row_ids_size) +{ + return sizeof(UInt32) + entry_size * (sizeof(T) + sizeof(UInt32)) + row_ids_size * sizeof(RowID); +} + +} // namespace DB::DM::InvertedIndex diff --git a/dbms/src/Storages/DeltaMerge/Index/InvertedIndex/Reader.cpp b/dbms/src/Storages/DeltaMerge/Index/InvertedIndex/Reader.cpp new file mode 100644 index 00000000000..4f822e2e8d6 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/InvertedIndex/Reader.cpp @@ -0,0 +1,261 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +namespace DB::ErrorCodes +{ +extern const int BAD_ARGUMENTS; +extern const int ABORTED; +} // namespace DB::ErrorCodes + +namespace DB::DM +{ + +InvertedIndexReaderPtr InvertedIndexReader::view(const DataTypePtr & type, std::string_view path) +{ + auto type_id = type->isNullable() ? 
dynamic_cast(*type).getNestedType()->getTypeId() + : type->getTypeId(); + switch (type_id) + { + case TypeIndex::UInt8: + return std::make_shared>(path); + case TypeIndex::Int8: + return std::make_shared>(path); + case TypeIndex::UInt16: + return std::make_shared>(path); + case TypeIndex::Int16: + return std::make_shared>(path); + case TypeIndex::UInt32: + return std::make_shared>(path); + case TypeIndex::Int32: + return std::make_shared>(path); + case TypeIndex::UInt64: + return std::make_shared>(path); + case TypeIndex::Int64: + return std::make_shared>(path); + case TypeIndex::Date: + return std::make_shared>(path); + case TypeIndex::DateTime: + return std::make_shared>(path); + case TypeIndex::Enum8: + return std::make_shared>(path); + case TypeIndex::Enum16: + return std::make_shared>(path); + case TypeIndex::MyDate: + case TypeIndex::MyDateTime: + case TypeIndex::MyTimeStamp: + return std::make_shared>(path); + case TypeIndex::MyTime: + return std::make_shared>(path); + default: + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unsupported type_id: {}", magic_enum::enum_name(type_id)); + } +} + +InvertedIndexReaderPtr InvertedIndexReader::view(const DataTypePtr & type, ReadBuffer & buf, size_t index_size) +{ + auto type_id = type->isNullable() ? 
dynamic_cast(*type).getNestedType()->getTypeId() + : type->getTypeId(); + switch (type_id) + { + case TypeIndex::UInt8: + return std::make_shared>(buf, index_size); + case TypeIndex::Int8: + return std::make_shared>(buf, index_size); + case TypeIndex::UInt16: + return std::make_shared>(buf, index_size); + case TypeIndex::Int16: + return std::make_shared>(buf, index_size); + case TypeIndex::UInt32: + return std::make_shared>(buf, index_size); + case TypeIndex::Int32: + return std::make_shared>(buf, index_size); + case TypeIndex::UInt64: + return std::make_shared>(buf, index_size); + case TypeIndex::Int64: + return std::make_shared>(buf, index_size); + case TypeIndex::Date: + return std::make_shared>(buf, index_size); + case TypeIndex::DateTime: + return std::make_shared>(buf, index_size); + case TypeIndex::Enum8: + return std::make_shared>(buf, index_size); + case TypeIndex::Enum16: + return std::make_shared>(buf, index_size); + case TypeIndex::MyDate: + case TypeIndex::MyDateTime: + case TypeIndex::MyTimeStamp: + return std::make_shared>(buf, index_size); + case TypeIndex::MyTime: + return std::make_shared>(buf, index_size); + default: + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unsupported type_id: {}", magic_enum::enum_name(type_id)); + } +} + +template +void InvertedIndexMemoryReader::load(ReadBuffer & read_buf, size_t index_size) +{ + // 0. check version + UInt8 version; + readIntBinary(version, read_buf); + RUNTIME_CHECK(version == magic_enum::enum_integer(InvertedIndex::Version::V1)); + index_size -= sizeof(UInt8); + + // 1. read all data + std::vector buf(index_size); + RUNTIME_CHECK(read_buf.readBig(buf.data(), index_size) == index_size); + + // 2. check magic flag + size_t data_size = index_size - InvertedIndex::MagicFlagLength; + if (memcmp(buf.data() + data_size, InvertedIndex::MagicFlag.data(), InvertedIndex::MagicFlagLength) != 0) + throw Exception(ErrorCodes::ABORTED, "Invalid magic flag"); + + // 3. 
read meta size + data_size = data_size - sizeof(UInt32); + UInt32 meta_size = *reinterpret_cast(buf.data() + data_size); + + // 4. read meta + ReadBufferFromMemory buffer(buf.data() + data_size - meta_size, meta_size); + InvertedIndex::Meta meta; + data_size = data_size - meta_size; + InvertedIndex::Meta::deserialize(meta, buffer); + + // 5. read blocks & build index + buffer = ReadBufferFromMemory(buf.data(), data_size); + for (const auto meta_entry : meta.entries) + { + auto count = buffer.count(); + InvertedIndex::Block block; + InvertedIndex::Block::deserialize(block, buffer); + RUNTIME_CHECK(buffer.count() - count == meta_entry.size); + for (const auto & block_entry : block.entries) + { + auto [value, row_ids] = block_entry; + index[value] = row_ids; + } + } +} + +template +void InvertedIndexMemoryReader::search(BitmapFilterPtr & bitmap_filter, const Key & key) const +{ + T real_key = key; + auto it = index.find(real_key); + if (it != index.end()) + bitmap_filter->set(it->second, nullptr); +} + +template +void InvertedIndexMemoryReader::searchRange(BitmapFilterPtr & bitmap_filter, const Key & begin, const Key & end) + const +{ + T real_begin = begin; + T real_end = end; + auto index_begin = index.lower_bound(real_begin); + auto index_end = index.upper_bound(real_end); + for (auto it = index_begin; it != index_end; ++it) + bitmap_filter->set(it->second, nullptr); +} + +template +void InvertedIndexFileReader::loadMeta(ReadBuffer & read_buf, size_t index_size) +{ + // 0. check version + UInt8 version; + readIntBinary(version, read_buf); + RUNTIME_CHECK(version == magic_enum::enum_integer(InvertedIndex::Version::V1)); + index_size -= sizeof(UInt8); + + // 1. read all data + std::vector buf(index_size); + RUNTIME_CHECK(read_buf.readBig(buf.data(), index_size) == index_size); + + // 2. 
check magic flag + size_t data_size = index_size - InvertedIndex::MagicFlagLength; + if (memcmp(buf.data() + data_size, InvertedIndex::MagicFlag.data(), InvertedIndex::MagicFlagLength) != 0) + throw Exception(ErrorCodes::ABORTED, "Invalid magic flag"); + + // 3. read meta size + data_size = data_size - sizeof(UInt32); + UInt32 meta_size = *reinterpret_cast(buf.data() + data_size); + + // 4. read meta + data_size = data_size - meta_size; + ReadBufferFromMemory buffer(buf.data() + data_size, meta_size); + InvertedIndex::Meta::deserialize(meta, buffer); +} + +template +void InvertedIndexFileReader::search(BitmapFilterPtr & bitmap_filter, const Key & key) const +{ + T real_key = key; + auto it = std::find_if(meta.entries.begin(), meta.entries.end(), [&](const auto & entry) { + return entry.min <= real_key && entry.max >= real_key; + }); + if (it == meta.entries.end()) + return; + + ReadBufferFromFile file_buf(path, DBMS_DEFAULT_BUFFER_SIZE, O_RDONLY); + file_buf.seek(it->offset, SEEK_SET); + InvertedIndex::Block::search(bitmap_filter, file_buf, real_key); +} + +template +void InvertedIndexFileReader::searchRange(BitmapFilterPtr & bitmap_filter, const Key & begin, const Key & end) const +{ + T real_begin = begin; + T real_end = end; + // max < begin + auto meta_begin = std::lower_bound( + meta.entries.begin(), + meta.entries.end(), + real_begin, + [](const auto & entry, const auto & key) { return entry.max < key; }); + // min > end + auto meta_end + = std::upper_bound(meta_begin, meta.entries.end(), real_end, [](const auto & key, const auto & entry) { + return key < entry.min; + }); + + ReadBufferFromFile file_buf(path, DBMS_DEFAULT_BUFFER_SIZE, O_RDONLY); + for (auto it = meta_begin; it != meta_end; ++it) + { + file_buf.seek(it->offset, SEEK_SET); + InvertedIndex::Block::searchRange(bitmap_filter, file_buf, real_begin, real_end); + } +} + +template class InvertedIndexMemoryReader; +template class InvertedIndexMemoryReader; +template class InvertedIndexMemoryReader; 
+template class InvertedIndexMemoryReader; +template class InvertedIndexMemoryReader; +template class InvertedIndexMemoryReader; +template class InvertedIndexMemoryReader; +template class InvertedIndexMemoryReader; +template class InvertedIndexFileReader; +template class InvertedIndexFileReader; +template class InvertedIndexFileReader; +template class InvertedIndexFileReader; +template class InvertedIndexFileReader; +template class InvertedIndexFileReader; +template class InvertedIndexFileReader; +template class InvertedIndexFileReader; + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/InvertedIndex/Reader.h b/dbms/src/Storages/DeltaMerge/Index/InvertedIndex/Reader.h new file mode 100644 index 00000000000..72aadf073b3 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/InvertedIndex/Reader.h @@ -0,0 +1,99 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace DB::DM +{ +/// Read a InvertedIndex file. 
+class InvertedIndexReader : public ICacheableLocalIndexReader +{ +public: + using Key = UInt64; + +public: + explicit InvertedIndexReader() = default; + ~InvertedIndexReader() override = default; + + static InvertedIndexReaderPtr view(const DataTypePtr & type, std::string_view path); + static InvertedIndexReaderPtr view(const DataTypePtr & type, ReadBuffer & buf, size_t index_size); + + virtual void search(BitmapFilterPtr & bitmap_filter, const Key & key) const = 0; + // [begin, end] + virtual void searchRange(BitmapFilterPtr & bitmap_filter, const Key & begin, const Key & end) const = 0; +}; + +/// Read a InvertedIndex file by loading it into memory. +/// Its performance is better than InvertedIndexFileReader but it consumes more memory. +template +class InvertedIndexMemoryReader : public InvertedIndexReader +{ +private: + void load(ReadBuffer & buf, size_t index_size); + +public: + explicit InvertedIndexMemoryReader(std::string_view path) + { + ReadBufferFromFile buf(path.data()); + load(buf, Poco::File(path.data()).getSize()); + } + + InvertedIndexMemoryReader(ReadBuffer & buf, size_t index_size) { load(buf, index_size); } + + ~InvertedIndexMemoryReader() override = default; + + void search(BitmapFilterPtr & bitmap_filter, const Key & key) const override; + void searchRange(BitmapFilterPtr & bitmap_filter, const Key & begin, const Key & end) const override; + +private: + std::map> index; // set by load +}; + +/// Read a InvertedIndex file by reading it from disk. +/// Its memory usage is minimal but its performance is worse than InvertedIndexMemoryReader. 
+template +class InvertedIndexFileReader : public InvertedIndexReader +{ +private: + void loadMeta(ReadBuffer & buf, size_t index_size); + +public: + explicit InvertedIndexFileReader(std::string_view path) + : path(path) + { + ReadBufferFromFile buffer(path.data()); + loadMeta(buffer, Poco::File(path.data()).getSize()); + } + + ~InvertedIndexFileReader() override = default; + + void search(BitmapFilterPtr & bitmap_filter, const Key & key) const override; + void searchRange(BitmapFilterPtr & bitmap_filter, const Key & begin, const Key & end) const override; + +private: + // Since this viewer will be used in multiple threads, + // only store the path and load the file when needed. + // Warning: Do not shared file_buf between threads. + const String path; + InvertedIndex::Meta meta; // set by loadMeta +}; + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Perf.cpp b/dbms/src/Storages/DeltaMerge/Index/InvertedIndex/Reader_fwd.h similarity index 73% rename from dbms/src/Storages/DeltaMerge/Index/VectorIndex/Perf.cpp rename to dbms/src/Storages/DeltaMerge/Index/InvertedIndex/Reader_fwd.h index 18b075a0ae8..c591b86ea5e 100644 --- a/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Perf.cpp +++ b/dbms/src/Storages/DeltaMerge/Index/InvertedIndex/Reader_fwd.h @@ -1,4 +1,4 @@ -// Copyright 2024 PingCAP, Inc. +// Copyright 2025 PingCAP, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,11 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include +#pragma once -namespace DB::PerfContext +#include + +namespace DB::DM { -thread_local VectorSearchPerfContext vector_search = {}; +class InvertedIndexReader; +using InvertedIndexReaderPtr = std::shared_ptr; -} +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/InvertedIndex/Writer.cpp b/dbms/src/Storages/DeltaMerge/Index/InvertedIndex/Writer.cpp new file mode 100644 index 00000000000..58cc4e2a17e --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/InvertedIndex/Writer.cpp @@ -0,0 +1,280 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include + +#include + +namespace DB::ErrorCodes +{ +extern const int ABORTED; +extern const int BAD_ARGUMENTS; +} // namespace DB::ErrorCodes + +namespace DB::DM +{ + +template +void InvertedIndexWriterInternal::addBlock( + const IColumn & column, + const ColumnVector * del_mark, + ProceedCheckFn should_proceed) +{ + // Note: column may be nullable. + const bool is_nullable = column.isColumnNullable(); + const auto * col_vector + = is_nullable ? checkAndGetNestedColumn>(&column) : checkAndGetColumn>(&column); + RUNTIME_CHECK_MSG(col_vector, "ColumnVector is expected, get: {}, T: {}", column.getName(), typeid(T).name()); + const auto & col_data = col_vector->getData(); + + const auto * null_map = is_nullable ? 
&(checkAndGetColumn(&column)->getNullMapData()) : nullptr; + const auto * del_mark_data = del_mark ? &(del_mark->getData()) : nullptr; + + Stopwatch w; + SCOPE_EXIT({ total_duration += w.elapsedSeconds(); }); + + Stopwatch w_proceed_check(CLOCK_MONOTONIC_COARSE); + + for (size_t i = 0; i < col_data.size(); ++i) + { + auto row_offset = added_rows; + ++added_rows; + + if (unlikely(i % 100 == 0 && w_proceed_check.elapsedSeconds() > 0.5)) + { + // The check of should_proceed could be non-trivial, so do it not too often. + w_proceed_check.restart(); + if (!should_proceed()) + throw Exception(ErrorCodes::ABORTED, "Index build is interrupted"); + } + + // Ignore rows with del_mark, as the column values are not meaningful. + if (del_mark_data != nullptr && (*del_mark_data)[i]) + continue; + + // Ignore NULL values, as they are not meaningful to store in index. + if (null_map && (*null_map)[i]) + continue; + + index[col_data[i]].push_back(row_offset); + } +} + +template +void InvertedIndexWriterInternal::saveToBuffer(WriteBuffer & write_buf) const +{ + size_t offset = 0; + + // 0. write version + UInt8 version = magic_enum::enum_integer(InvertedIndex::Version::V1); + writeIntBinary(version, write_buf); + offset += sizeof(UInt8); + + InvertedIndex::Meta meta; + + // 1. write data by block + InvertedIndex::Block block; + size_t row_ids_size = 0; + auto write_block = [&] { + block.serialize(write_buf); + size_t total_size = write_buf.count(); + meta.entries.emplace_back(offset, total_size - offset, block.entries.front().value, block.entries.back().value); + block.entries.clear(); + offset = total_size; + row_ids_size = 0; + }; + + for (const auto & [key, row_ids] : index) + { + block.entries.emplace_back(key, row_ids); + row_ids_size += row_ids.size(); + + // write block + if (InvertedIndex::getBlockSize(block.entries.size(), row_ids_size) >= InvertedIndex::BlockSize) + write_block(); + } + if (!block.entries.empty()) + write_block(); + + // 2. 
write meta + offset = write_buf.count(); + meta.serialize(write_buf); + + // 3. write meta size + UInt32 meta_size = write_buf.count() - offset; + write_buf.write(reinterpret_cast(&meta_size), sizeof(meta_size)); + + // 4. write magic flag + write_buf.write(InvertedIndex::MagicFlag.data(), InvertedIndex::MagicFlagLength); + + // 5. record uncompressed size + uncompressed_size = write_buf.count(); +} + +template +void InvertedIndexWriterInternal::saveFileProps(dtpb::IndexFilePropsV2 * pb_idx) const +{ + auto * pb_inv_idx = pb_idx->mutable_inverted_index(); + pb_inv_idx->set_uncompressed_size(uncompressed_size); +} + +template +InvertedIndexWriterInternal::~InvertedIndexWriterInternal() +{ + GET_METRIC(tiflash_inverted_index_duration, type_build).Observe(total_duration); + GET_METRIC(tiflash_inverted_index_active_instances, type_build).Decrement(); +} + +template +void InvertedIndexWriterOnDisk::saveToFile() const +{ + Stopwatch w; + SCOPE_EXIT({ writer.total_duration += w.elapsedSeconds(); }); + + WriteBufferFromFile write_buf(index_file); + writer.saveToBuffer(write_buf); + write_buf.sync(); +} + +template +void InvertedIndexWriterInMemory::saveToBuffer(WriteBuffer & write_buf) +{ + Stopwatch w; + SCOPE_EXIT({ writer.total_duration += w.elapsedSeconds(); }); + + writer.saveToBuffer(write_buf); +} + +template class InvertedIndexWriterInternal; +template class InvertedIndexWriterInternal; +template class InvertedIndexWriterInternal; +template class InvertedIndexWriterInternal; +template class InvertedIndexWriterInternal; +template class InvertedIndexWriterInternal; +template class InvertedIndexWriterInternal; +template class InvertedIndexWriterInternal; +template class InvertedIndexWriterOnDisk; +template class InvertedIndexWriterOnDisk; +template class InvertedIndexWriterOnDisk; +template class InvertedIndexWriterOnDisk; +template class InvertedIndexWriterOnDisk; +template class InvertedIndexWriterOnDisk; +template class InvertedIndexWriterOnDisk; +template class 
InvertedIndexWriterOnDisk; +template class InvertedIndexWriterInMemory; +template class InvertedIndexWriterInMemory; +template class InvertedIndexWriterInMemory; +template class InvertedIndexWriterInMemory; +template class InvertedIndexWriterInMemory; +template class InvertedIndexWriterInMemory; +template class InvertedIndexWriterInMemory; +template class InvertedIndexWriterInMemory; + +LocalIndexWriterOnDiskPtr createOnDiskInvertedIndexWriter( + IndexID index_id, + std::string_view index_file, + const TiDB::InvertedIndexDefinitionPtr & definition) +{ + if (!definition) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Invalid index kind or definition"); + + if (definition->type_size == sizeof(UInt8) && !definition->is_signed) + { + return std::make_shared>(index_id, index_file); + } + else if (definition->type_size == sizeof(Int8) && definition->is_signed) + { + return std::make_shared>(index_id, index_file); + } + else if (definition->type_size == sizeof(UInt16) && !definition->is_signed) + { + return std::make_shared>(index_id, index_file); + } + else if (definition->type_size == sizeof(Int16) && definition->is_signed) + { + return std::make_shared>(index_id, index_file); + } + else if (definition->type_size == sizeof(UInt32) && !definition->is_signed) + { + return std::make_shared>(index_id, index_file); + } + else if (definition->type_size == sizeof(Int32) && definition->is_signed) + { + return std::make_shared>(index_id, index_file); + } + else if (definition->type_size == sizeof(UInt64) && !definition->is_signed) + { + return std::make_shared>(index_id, index_file); + } + else if (definition->type_size == sizeof(Int64) && definition->is_signed) + { + return std::make_shared>(index_id, index_file); + } + else + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unsupported type size {}", definition->type_size); + } +} + +LocalIndexWriterInMemoryPtr createInMemoryInvertedIndexWriter( + IndexID index_id, + const TiDB::InvertedIndexDefinitionPtr & definition) +{ + if 
(!definition) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Invalid index kind or definition"); + + if (definition->type_size == sizeof(UInt8) && !definition->is_signed) + { + return std::make_shared>(index_id); + } + else if (definition->type_size == sizeof(Int8) && definition->is_signed) + { + return std::make_shared>(index_id); + } + else if (definition->type_size == sizeof(UInt16) && !definition->is_signed) + { + return std::make_shared>(index_id); + } + else if (definition->type_size == sizeof(Int16) && definition->is_signed) + { + return std::make_shared>(index_id); + } + else if (definition->type_size == sizeof(UInt32) && !definition->is_signed) + { + return std::make_shared>(index_id); + } + else if (definition->type_size == sizeof(Int32) && definition->is_signed) + { + return std::make_shared>(index_id); + } + else if (definition->type_size == sizeof(UInt64) && !definition->is_signed) + { + return std::make_shared>(index_id); + } + else if (definition->type_size == sizeof(Int64) && definition->is_signed) + { + return std::make_shared>(index_id); + } + else + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unsupported type size {}", definition->type_size); + } +} + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/InvertedIndex/Writer.h b/dbms/src/Storages/DeltaMerge/Index/InvertedIndex/Writer.h new file mode 100644 index 00000000000..aa56cb72c75 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/InvertedIndex/Writer.h @@ -0,0 +1,105 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace DB::DM +{ + +template +class InvertedIndexWriterInternal +{ +public: + using Key = T; + using RowID = InvertedIndex::RowID; + +public: + ~InvertedIndexWriterInternal(); + + using ProceedCheckFn = LocalIndexWriter::ProceedCheckFn; + void addBlock(const IColumn & column, const ColumnVector * del_mark, ProceedCheckFn should_proceed); + + void saveFileProps(dtpb::IndexFilePropsV2 * pb_idx) const; + void saveToBuffer(WriteBuffer & write_buf) const; + +public: + UInt64 added_rows = 0; // Includes nulls and deletes. Used as the index key. + std::map> index; + mutable double total_duration = 0; + mutable size_t uncompressed_size = 0; +}; + +template +class InvertedIndexWriterInMemory : public LocalIndexWriterInMemory +{ +public: + explicit InvertedIndexWriterInMemory(IndexID index_id) + : LocalIndexWriterInMemory(index_id) + , writer() + {} + + void saveToBuffer(WriteBuffer & write_buf) override; + + void addBlock(const IColumn & column, const ColumnVector * del_mark, ProceedCheckFn should_proceed) override + { + writer.addBlock(column, del_mark, should_proceed); + } + + void saveFileProps(dtpb::IndexFilePropsV2 * pb_idx) const override { writer.saveFileProps(pb_idx); } + + dtpb::IndexFileKind kind() const override { return dtpb::IndexFileKind::INVERTED_INDEX; } + +private: + InvertedIndexWriterInternal writer; +}; + +// FIXME: The building process is still in-memory. Only the index file is saved to disk. 
+template +class InvertedIndexWriterOnDisk : public LocalIndexWriterOnDisk +{ +public: + explicit InvertedIndexWriterOnDisk(IndexID index_id, std::string_view index_file) + : LocalIndexWriterOnDisk(index_id, index_file) + , writer() + {} + + void saveToFile() const override; + + void addBlock(const IColumn & column, const ColumnVector * del_mark, ProceedCheckFn should_proceed) override + { + writer.addBlock(column, del_mark, should_proceed); + } + + void saveFileProps(dtpb::IndexFilePropsV2 * pb_idx) const override { writer.saveFileProps(pb_idx); } + + dtpb::IndexFileKind kind() const override { return dtpb::IndexFileKind::INVERTED_INDEX; } + +private: + InvertedIndexWriterInternal writer; +}; + + +LocalIndexWriterOnDiskPtr createOnDiskInvertedIndexWriter( + IndexID index_id, + std::string_view index_file, + const TiDB::InvertedIndexDefinitionPtr & definition); + +LocalIndexWriterInMemoryPtr createInMemoryInvertedIndexWriter( + IndexID index_id, + const TiDB::InvertedIndexDefinitionPtr & definition); + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/LocalIndexInfo.cpp b/dbms/src/Storages/DeltaMerge/Index/LocalIndexInfo.cpp index 123003830b1..b58832626ae 100644 --- a/dbms/src/Storages/DeltaMerge/Index/LocalIndexInfo.cpp +++ b/dbms/src/Storages/DeltaMerge/Index/LocalIndexInfo.cpp @@ -180,13 +180,7 @@ LocalIndexInfosChangeset generateLocalIndexInfos( if (idx.state == TiDB::StatePublic || idx.state == TiDB::StateWriteReorganization) { // create a new index - new_index_infos->emplace_back(LocalIndexInfo{ - .kind = idx.columnarIndexKind(), - .index_id = idx.id, - .column_id = column_id, - // Only one of the below will be set - .def_vector_index = idx.vector_index, - }); + new_index_infos->emplace_back(LocalIndexInfo(idx.id, column_id, idx.vector_index)); newly_added.emplace_back(idx.id); index_ids_in_new_table.emplace(idx.id); } diff --git a/dbms/src/Storages/DeltaMerge/Index/LocalIndexInfo.h b/dbms/src/Storages/DeltaMerge/Index/LocalIndexInfo.h 
index 9671e571931..9a87e6ffac7 100644 --- a/dbms/src/Storages/DeltaMerge/Index/LocalIndexInfo.h +++ b/dbms/src/Storages/DeltaMerge/Index/LocalIndexInfo.h @@ -18,6 +18,7 @@ #include #include #include +#include #include #include @@ -45,6 +46,21 @@ struct LocalIndexInfo ColumnID column_id = DB::EmptyColumnID; TiDB::VectorIndexDefinitionPtr def_vector_index = nullptr; + TiDB::InvertedIndexDefinitionPtr def_inverted_index = nullptr; + + LocalIndexInfo(IndexID index_id_, ColumnID column_id_, const TiDB::VectorIndexDefinitionPtr & def) + : kind(TiDB::ColumnarIndexKind::Vector) + , index_id(index_id_) + , column_id(column_id_) + , def_vector_index(def) + {} + + LocalIndexInfo(IndexID index_id_, ColumnID column_id_, const TiDB::InvertedIndexDefinitionPtr & def) + : kind(TiDB::ColumnarIndexKind::Inverted) + , index_id(index_id_) + , column_id(column_id_) + , def_inverted_index(def) + {} dtpb::IndexFileKind getKindAsDtpb() const { @@ -52,6 +68,8 @@ struct LocalIndexInfo { case TiDB::ColumnarIndexKind::Vector: return dtpb::IndexFileKind::VECTOR_INDEX; + case TiDB::ColumnarIndexKind::Inverted: + return dtpb::IndexFileKind::INVERTED_INDEX; default: RUNTIME_CHECK_MSG(false, "Unsupported index kind: {}", magic_enum::enum_name(kind)); } diff --git a/dbms/src/Storages/DeltaMerge/Index/LocalIndexWriter.cpp b/dbms/src/Storages/DeltaMerge/Index/LocalIndexWriter.cpp new file mode 100644 index 00000000000..0212aec07fa --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/LocalIndexWriter.cpp @@ -0,0 +1,74 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + + +namespace DB::DM +{ + +LocalIndexWriterInMemoryPtr LocalIndexWriter::createInMemory(const LocalIndexInfo & index_info) +{ + switch (index_info.kind) + { + case TiDB::ColumnarIndexKind::Vector: + return std::make_shared(index_info.index_id, index_info.def_vector_index); + case TiDB::ColumnarIndexKind::Inverted: + return createInMemoryInvertedIndexWriter(index_info.index_id, index_info.def_inverted_index); + default: + RUNTIME_CHECK_MSG(false, "Unsupported index kind: {}", magic_enum::enum_name(index_info.kind)); + } +} + +LocalIndexWriterOnDiskPtr LocalIndexWriter::createOnDisk(std::string_view index_file, const LocalIndexInfo & index_info) +{ + switch (index_info.kind) + { + case TiDB::ColumnarIndexKind::Vector: + return std::make_shared(index_info.index_id, index_file, index_info.def_vector_index); + case TiDB::ColumnarIndexKind::Inverted: + return createOnDiskInvertedIndexWriter(index_info.index_id, index_file, index_info.def_inverted_index); + default: + RUNTIME_CHECK_MSG(false, "Unsupported index kind: {}", magic_enum::enum_name(index_info.kind)); + } +} + +dtpb::IndexFilePropsV2 LocalIndexWriterOnDisk::finalize() +{ + saveToFile(); + dtpb::IndexFilePropsV2 pb_idx; + pb_idx.set_index_id(index_id); + pb_idx.set_file_size(Poco::File(index_file.data()).getSize()); + pb_idx.set_kind(kind()); + saveFileProps(&pb_idx); + return pb_idx; +} + +dtpb::IndexFilePropsV2 LocalIndexWriterInMemory::finalize( + WriteBuffer & write_buf, + std::function get_materialized_size) +{ + saveToBuffer(write_buf); + 
dtpb::IndexFilePropsV2 pb_idx; + pb_idx.set_index_id(index_id); + pb_idx.set_file_size(get_materialized_size()); + pb_idx.set_kind(kind()); + saveFileProps(&pb_idx); + return pb_idx; +} + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/LocalIndexWriter.h b/dbms/src/Storages/DeltaMerge/Index/LocalIndexWriter.h new file mode 100644 index 00000000000..e1ea3357eb2 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/LocalIndexWriter.h @@ -0,0 +1,89 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include +#include +#include + + +namespace DB::DM +{ + +class LocalIndexWriter +{ +public: + using ProceedCheckFn = std::function; + +public: + explicit LocalIndexWriter(IndexID index_id_) + : index_id(index_id_) + {} + + static LocalIndexWriterInMemoryPtr createInMemory(const LocalIndexInfo & index_info); + static LocalIndexWriterOnDiskPtr createOnDisk(std::string_view index_file, const LocalIndexInfo & index_info); + + virtual ~LocalIndexWriter() = default; + + virtual void addBlock(const IColumn & column, const ColumnVector * del_mark, ProceedCheckFn should_proceed) + = 0; + +protected: + virtual void saveFileProps(dtpb::IndexFilePropsV2 * pb_idx) const = 0; + + virtual dtpb::IndexFileKind kind() const = 0; + +protected: + IndexID index_id; +}; + +class LocalIndexWriterInMemory : public LocalIndexWriter +{ +public: + explicit LocalIndexWriterInMemory(IndexID index_id_) + : LocalIndexWriter(index_id_) + {} + + ~LocalIndexWriterInMemory() override = default; + + dtpb::IndexFilePropsV2 finalize(WriteBuffer & write_buf, std::function get_materialized_size); + +protected: + virtual void saveToBuffer(WriteBuffer & write_buf) = 0; +}; + +class LocalIndexWriterOnDisk : public LocalIndexWriter +{ +public: + explicit LocalIndexWriterOnDisk(IndexID index_id_, std::string_view index_file_) + : LocalIndexWriter(index_id_) + , index_file(index_file_) + {} + + ~LocalIndexWriterOnDisk() override = default; + + dtpb::IndexFilePropsV2 finalize(); + +protected: + virtual void saveToFile() const = 0; + +protected: + String index_file; +}; + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/LocalIndexWriter_fwd.h b/dbms/src/Storages/DeltaMerge/Index/LocalIndexWriter_fwd.h new file mode 100644 index 00000000000..81b45ae42cb --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/LocalIndexWriter_fwd.h @@ -0,0 +1,31 @@ +// Copyright 2025 PingCAP, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace DB::DM +{ + +class LocalIndexWriter; +using LocalIndexWriterPtr = std::shared_ptr; + +class LocalIndexWriterInMemory; +using LocalIndexWriterInMemoryPtr = std::shared_ptr; + +class LocalIndexWriterOnDisk; +using LocalIndexWriterOnDiskPtr = std::shared_ptr; + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Perf.h b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Perf.h index 6fb3f1a7405..1032a292b8d 100644 --- a/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Perf.h +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Perf.h @@ -14,24 +14,57 @@ #pragma once +#include #include -/// Remove the population of thread_local from Poco -#ifdef thread_local -#undef thread_local -#endif +#include -namespace DB::PerfContext +namespace DB::DM { -struct VectorSearchPerfContext +/// Unlike FileCache, this is supposed to be wrapped inside a shared_ptr because an input stream could be +/// accessed in different threads (although there is no concurrent access) throughout its lifetime. So +/// thread local is no longer useful. 
+struct VectorIndexPerf { - size_t visited_nodes = 0; - size_t discarded_nodes = 0; // Rows filtered out by MVCC + uint32_t n_from_cf_index = 0; + uint32_t n_from_cf_noindex = 0; + uint32_t n_from_dmf_index = 0; + uint32_t n_from_dmf_noindex = 0; - void reset() { *this = {}; } -}; + // ============================================================ + // Below: During search + uint32_t n_searches = 0; + uint32_t total_search_ms = 0; + uint64_t visited_nodes = 0; + uint64_t discarded_nodes = 0; // Rows filtered out by MVCC + uint64_t returned_nodes = 0; + uint32_t n_dm_searches = 0; // For calculating avg below + uint32_t dm_packs_in_file = 0; + uint32_t dm_packs_before_search = 0; + uint32_t dm_packs_after_search = 0; + // ============================================================ + + // ============================================================ + // Below: During loading index + uint32_t total_load_ms = 0; + uint32_t load_from_cache = 0; + uint32_t load_from_column_file = 0; // contains ColumnFile load time + uint32_t load_from_stable_s3 = 0; + uint32_t load_from_stable_disk = 0; + // ============================================================ -extern thread_local VectorSearchPerfContext vector_search; + // ============================================================ + // Below: During reading column data + uint32_t n_dm_reads = 0; // For calculating avg below + uint32_t total_dm_read_vec_ms = 0; + uint32_t total_dm_read_others_ms = 0; + uint32_t n_cf_reads = 0; // For calculating avg below + uint32_t total_cf_read_vec_ms = 0; + uint32_t total_cf_read_others_ms = 0; + // ============================================================ + + static VectorIndexPerfPtr create() { return std::make_shared(); } +}; -} // namespace DB::PerfContext +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/File/DMFileWithVectorIndexBlockInputStream_fwd.h b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Perf_fwd.h similarity index 81% rename from 
dbms/src/Storages/DeltaMerge/File/DMFileWithVectorIndexBlockInputStream_fwd.h rename to dbms/src/Storages/DeltaMerge/Index/VectorIndex/Perf_fwd.h index 6e88a873070..8ea9c6f427b 100644 --- a/dbms/src/Storages/DeltaMerge/File/DMFileWithVectorIndexBlockInputStream_fwd.h +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Perf_fwd.h @@ -19,8 +19,8 @@ namespace DB::DM { -class DMFileWithVectorIndexBlockInputStream; +struct VectorIndexPerf; -using DMFileWithVectorIndexBlockInputStreamPtr = std::shared_ptr; +using VectorIndexPerfPtr = std::shared_ptr; } // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Reader.cpp b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Reader.cpp index d23453dff8f..1d069267e2c 100644 --- a/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Reader.cpp +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Reader.cpp @@ -30,16 +30,14 @@ namespace DB::DM VectorIndexReaderPtr VectorIndexReader::createFromMmap( const dtpb::IndexFilePropsV2Vector & file_props, + const VectorIndexPerfPtr & perf, std::string_view path) { tipb::VectorDistanceMetric metric; RUNTIME_CHECK(tipb::VectorDistanceMetric_Parse(file_props.distance_metric(), &metric)); RUNTIME_CHECK(metric != tipb::VectorDistanceMetric::INVALID_DISTANCE_METRIC); - Stopwatch w; - SCOPE_EXIT({ GET_METRIC(tiflash_vector_index_duration, type_view).Observe(w.elapsedSeconds()); }); - - auto vi = std::make_shared(file_props); + auto vi = std::make_shared(/* is_in_memory */ false, file_props, perf); vi->index = USearchImplType::make( unum::usearch::metric_punned_t( // @@ -71,16 +69,14 @@ VectorIndexReaderPtr VectorIndexReader::createFromMmap( VectorIndexReaderPtr VectorIndexReader::createFromMemory( const dtpb::IndexFilePropsV2Vector & file_props, + const VectorIndexPerfPtr & perf, ReadBuffer & buf) { tipb::VectorDistanceMetric metric; RUNTIME_CHECK(tipb::VectorDistanceMetric_Parse(file_props.distance_metric(), &metric)); RUNTIME_CHECK(metric != 
tipb::VectorDistanceMetric::INVALID_DISTANCE_METRIC); - Stopwatch w; - SCOPE_EXIT({ GET_METRIC(tiflash_vector_index_duration, type_view).Observe(w.elapsedSeconds()); }); - - auto vi = std::make_shared(file_props); + auto vi = std::make_shared(/* is_in_memory */ true, file_props, perf); vi->index = USearchImplType::make( unum::usearch::metric_punned_t( // @@ -101,13 +97,15 @@ VectorIndexReaderPtr VectorIndexReader::createFromMemory( RUNTIME_CHECK_MSG(result, "Failed to load vector index: {}", result.error.what()); auto current_memory_usage = vi->index.memory_usage(); - GET_METRIC(tiflash_vector_index_memory_usage, type_view).Increment(static_cast(current_memory_usage)); + GET_METRIC(tiflash_vector_index_memory_usage, type_load).Increment(static_cast(current_memory_usage)); vi->last_reported_memory_usage = current_memory_usage; return vi; } -auto VectorIndexReader::searchImpl(const ANNQueryInfoPtr & query_info, const RowFilter & valid_rows) const +VectorIndexReader::SearchResults VectorIndexReader::search( + const ANNQueryInfoPtr & query_info, + const RowFilter & valid_rows) const { RUNTIME_CHECK(query_info->ref_vec_f32().size() >= sizeof(UInt32)); auto query_vec_size = readLittleEndian(query_info->ref_vec_f32().data()); @@ -141,9 +139,9 @@ auto VectorIndexReader::searchImpl(const ANNQueryInfoPtr & query_info, const Row try { // Note: We don't increase the thread_local perf, because search runs on other threads. - visited_nodes++; + visited_nodes.fetch_add(1, std::memory_order_relaxed); if (!valid_rows[key]) - discarded_nodes++; + discarded_nodes.fetch_add(1, std::memory_order_relaxed); return valid_rows[key]; } catch (...) 
@@ -154,8 +152,13 @@ auto VectorIndexReader::searchImpl(const ANNQueryInfoPtr & query_info, const Row } }; - Stopwatch w; - SCOPE_EXIT({ GET_METRIC(tiflash_vector_index_duration, type_search).Observe(w.elapsedSeconds()); }); + Stopwatch w(CLOCK_MONOTONIC_COARSE); + SCOPE_EXIT({ + double elapsed = w.elapsedSeconds(); + perf->n_searches += 1; + perf->total_search_ms += elapsed * 1000; + GET_METRIC(tiflash_vector_index_duration, type_search).Observe(w.elapsedSeconds()); + }); // TODO(vector-index): Support efSearch. auto result = index.filtered_search( // @@ -163,37 +166,17 @@ auto VectorIndexReader::searchImpl(const ANNQueryInfoPtr & query_info, const Row query_info->top_k(), predicate); + perf->visited_nodes += visited_nodes; + perf->discarded_nodes += discarded_nodes; + perf->returned_nodes += result.size(); + if (has_exception_in_search) throw Exception(ErrorCodes::INCORRECT_QUERY, "Exception happened occurred during search"); - PerfContext::vector_search.visited_nodes += visited_nodes; - PerfContext::vector_search.discarded_nodes += discarded_nodes; - return result; -} - -std::vector VectorIndexReader::search( - const ANNQueryInfoPtr & query_info, - const RowFilter & valid_rows) const -{ - auto result = searchImpl(query_info, valid_rows); - - // For some reason usearch does not always do the predicate for all search results. - // So we need to filter again. 
- const size_t result_size = result.size(); - std::vector search_results; - search_results.reserve(result_size); - for (size_t i = 0; i < result_size; ++i) - { - const auto rowid = result[i].member.key; - if (valid_rows[rowid]) - search_results.emplace_back(rowid, result[i].distance); - } - return search_results; -} + if (result.error) + throw Exception(ErrorCodes::INCORRECT_QUERY, "Search resulted in an error: {}", result.error.what()); -size_t VectorIndexReader::size() const -{ - return index.size(); + return result; } void VectorIndexReader::get(Key key, std::vector & out) const @@ -202,19 +185,42 @@ void VectorIndexReader::get(Key key, std::vector & out) const index.get(key, out.data()); } -VectorIndexReader::VectorIndexReader(const dtpb::IndexFilePropsV2Vector & file_props_) - : file_props(file_props_) +VectorIndexReader::VectorIndexReader( + bool is_in_memory_, + const dtpb::IndexFilePropsV2Vector & file_props_, + const VectorIndexPerfPtr & perf_) + : is_in_memory(is_in_memory_) + , file_props(file_props_) + , perf(perf_) { + RUNTIME_CHECK(perf_ != nullptr); RUNTIME_CHECK(file_props.dimensions() > 0); RUNTIME_CHECK(file_props.dimensions() <= TiDB::MAX_VECTOR_DIMENSION); - GET_METRIC(tiflash_vector_index_active_instances, type_view).Increment(); + if (is_in_memory) + { + GET_METRIC(tiflash_vector_index_active_instances, type_load).Increment(); + } + else + { + GET_METRIC(tiflash_vector_index_active_instances, type_view).Increment(); + } } VectorIndexReader::~VectorIndexReader() { - GET_METRIC(tiflash_vector_index_memory_usage, type_view).Decrement(static_cast(last_reported_memory_usage)); - GET_METRIC(tiflash_vector_index_active_instances, type_view).Decrement(); + if (is_in_memory) + { + GET_METRIC(tiflash_vector_index_memory_usage, type_load) + .Decrement(static_cast(last_reported_memory_usage)); + GET_METRIC(tiflash_vector_index_active_instances, type_load).Decrement(); + } + else + { + GET_METRIC(tiflash_vector_index_memory_usage, type_view) + 
.Decrement(static_cast(last_reported_memory_usage)); + GET_METRIC(tiflash_vector_index_active_instances, type_view).Decrement(); + } } } // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Reader.h b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Reader.h index 7b49968f763..0772a77a3d3 100644 --- a/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Reader.h +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Reader.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -35,44 +36,49 @@ class VectorIndexReader : public ICacheableLocalIndexReader public: /// The key is the row's offset in the DMFile. using Key = UInt32; - using Distance = Float32; - struct SearchResult - { - Key key; - Distance distance; - }; + using SearchResults = USearchImplType::search_result_t; /// True bit means the row is valid and should be kept in the search result. /// False bit lets the row filtered out and will search for more results. using RowFilter = BitmapFilterView; public: - static VectorIndexReaderPtr createFromMmap(const dtpb::IndexFilePropsV2Vector & file_props, std::string_view path); - static VectorIndexReaderPtr createFromMemory(const dtpb::IndexFilePropsV2Vector & file_props, ReadBuffer & buf); + static VectorIndexReaderPtr createFromMmap( + const dtpb::IndexFilePropsV2Vector & file_props, + const VectorIndexPerfPtr & perf, // must not be null + std::string_view path); + static VectorIndexReaderPtr createFromMemory( + const dtpb::IndexFilePropsV2Vector & file_props, + const VectorIndexPerfPtr & perf, // must not be null + ReadBuffer & buf); public: - explicit VectorIndexReader(const dtpb::IndexFilePropsV2Vector & file_props_); + explicit VectorIndexReader( + bool is_in_memory_, + const dtpb::IndexFilePropsV2Vector & file_props_, + const VectorIndexPerfPtr & perf_ // must not be null + ); ~VectorIndexReader() override; - // Invalid rows in `valid_rows` will be discared when applying the search - std::vector search(const 
ANNQueryInfoPtr & query_info, const RowFilter & valid_rows) const; - - size_t size() const; + /// The result is sorted by distance. + /// WARNING: Due to usearch's impl, invalid rows in `valid_rows` may be still contained in the search result. + /// WARNING: Drop the result as soon as possible, because it is "reader local", blocks more concurrent reads. + /// We choose to return search result directly without any copying to improve performance. + SearchResults search(const ANNQueryInfoPtr & query_info, const RowFilter & valid_rows) const; // Get the value (i.e. vector content) of a Key. void get(Key key, std::vector & out) const; -private: - auto searchImpl(const ANNQueryInfoPtr & query_info, const RowFilter & valid_rows) const; - public: + const bool is_in_memory; const dtpb::IndexFilePropsV2Vector file_props; private: USearchImplType index; + const VectorIndexPerfPtr perf; size_t last_reported_memory_usage = 0; }; diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/ColumnFileInputStream.cpp b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/ColumnFileInputStream.cpp new file mode 100644 index 00000000000..4c8e0a8e318 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/ColumnFileInputStream.cpp @@ -0,0 +1,125 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB::DM +{ + +SkippableBlockInputStreamPtr ColumnFileProvideVectorIndexInputStream::createOrFallback( + const VectorIndexStreamCtxPtr & ctx, + const ColumnFilePtr & column_file) +{ + RUNTIME_CHECK(ctx->data_provider != nullptr); + RUNTIME_CHECK(ctx->dm_context != nullptr); + + auto fallback = [&] { + ctx->perf->n_from_cf_noindex += 1; + return ColumnFileInputStream::create( + *ctx->dm_context, + column_file, + ctx->data_provider, + ctx->col_defs, + ctx->read_tag); + }; + + const auto tiny_file = std::dynamic_pointer_cast(column_file); + if (!tiny_file) + return fallback(); + const auto * index_info = tiny_file->findIndexInfo(ctx->ann_query_info->index_id()); + if (!index_info) + return fallback(); + + ctx->perf->n_from_cf_index += 1; + return std::make_shared(ctx, tiny_file); +} + +inline VectorIndexReaderPtr ColumnFileProvideVectorIndexInputStream::getVectorIndexReader() +{ + if (vec_index != nullptr) + return vec_index; + vec_index = VectorIndexReaderFromColumnFileTiny::load(ctx, *tiny_file); + return vec_index; +} + +Block ColumnFileProvideVectorIndexInputStream::read() +{ + // We require setReturnRows to be called before read. 
+ // See ConcatInputStream for a caller, e.g.: VectorIndexInputStream{ConcatInputStream{ColumnFileProvideVectorIndexInputStream}} + RUNTIME_CHECK(sorted_results.owner != nullptr); + RUNTIME_CHECK(vec_index != nullptr); + + // All rows are filtered out (or there is already a read()) + if (sorted_results.view.empty()) + return {}; + + Stopwatch w(CLOCK_MONOTONIC_COARSE); + + // read vector column from index + auto vec_column = ctx->vec_cd.type->createColumn(); + vec_column->reserve(sorted_results.view.size()); + for (const auto & row : sorted_results.view) + { + vec_index->get(row.rowid, ctx->vector_value); + vec_column->insertData( + reinterpret_cast(ctx->vector_value.data()), + ctx->vector_value.size() * sizeof(Float32)); + } + + ctx->perf->n_cf_reads += 1; + ctx->perf->total_cf_read_vec_ms += w.elapsedMillisecondsFromLastTime(); + + // read other column from ColumnFileTinyReader + // TODO: Optimize: We should be able to only read a few rows, instead of reading all rows then filter out. + Block block; + if (!ctx->rest_col_defs->empty()) + { + auto reader = tiny_file->getReader(*ctx->dm_context, ctx->data_provider, ctx->rest_col_defs, ctx->read_tag); + block = reader->readNextBlock(); + + ctx->filter.clear(); + ctx->filter.resize_fill(tiny_file->getRows(), 0); + for (const auto & row : sorted_results.view) + ctx->filter[row.rowid] = 1; + for (auto & col : block) + col.column = col.column->filter(ctx->filter, sorted_results.view.size()); + + RUNTIME_CHECK(block.rows() == sorted_results.view.size()); + } + + ctx->perf->total_cf_read_others_ms += w.elapsedMillisecondsFromLastTime(); + + auto index = ctx->header.getPositionByName(ctx->vec_cd.name); + block.insert(index, ColumnWithTypeAndName(std::move(vec_column), ctx->vec_cd.type, ctx->vec_cd.name)); + + // After a successful read, clear out the sorted_results view so that + // the next read will just return an empty block.
+ sorted_results.view = {}; + + return block; +} + +inline Block ColumnFileProvideVectorIndexInputStream::getHeader() const +{ + return ctx->header; +} + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/ColumnFileInputStream.h b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/ColumnFileInputStream.h new file mode 100644 index 00000000000..9a0618e1e77 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/ColumnFileInputStream.h @@ -0,0 +1,73 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include +#include + +namespace DB::DM +{ + +class ColumnFile; +using ColumnFilePtr = std::shared_ptr; +class ColumnFileTiny; +using ColumnFileTinyPtr = std::shared_ptr; + +class ColumnFileProvideVectorIndexInputStream + : public IProvideVectorIndex + , public NopSkippableBlockInputStream +{ +public: + static SkippableBlockInputStreamPtr createOrFallback( + const VectorIndexStreamCtxPtr & ctx, + const ColumnFilePtr & column_file); + + ColumnFileProvideVectorIndexInputStream(const VectorIndexStreamCtxPtr & ctx_, const ColumnFileTinyPtr & tiny_file_) + : ctx(ctx_) + , tiny_file(tiny_file_) + { + RUNTIME_CHECK(tiny_file != nullptr); + } + +public: // Implements IProvideVectorIndex + VectorIndexReaderPtr getVectorIndexReader() override; + + void setReturnRows(SearchResultView sorted_results_) override { sorted_results = sorted_results_; } + +public: // Implements IBlockInputStream + String getName() const override { return "VectorIndexColumnFile"; } + + Block getHeader() const override; + + // Note: The output block does not contain a start offset. + Block read() override; + +private: + // Note: Keep this struct small, because it will be constructed for each ColumnFileTiny that has an index. + // If you have common things, put it in ctx. Only put things that are different for each ColumnFileTiny here. + + const VectorIndexStreamCtxPtr ctx; + const ColumnFileTinyPtr tiny_file; + + VectorIndexReaderPtr vec_index = nullptr; + + // Set by setReturnRows(), cleared when a successful read() is done.
+ SearchResultView sorted_results; // Used to filter the output and learn what to read from VectorIndex +}; + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/ColumnFileInputStream_fwd.h b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/ColumnFileInputStream_fwd.h new file mode 100644 index 00000000000..3d884c6b2a6 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/ColumnFileInputStream_fwd.h @@ -0,0 +1,26 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace DB::DM +{ + +class ColumnFileProvideVectorIndexInputStream; + +using ColumnFileProvideVectorIndexInputStreamPtr = std::shared_ptr; + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/Ctx.cpp b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/Ctx.cpp new file mode 100644 index 00000000000..1c2998d21d7 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/Ctx.cpp @@ -0,0 +1,106 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +namespace DB::DM +{ + +VectorIndexStreamCtxPtr VectorIndexStreamCtx::create( + const LocalIndexCachePtr & index_cache_light, // nullable + const LocalIndexCachePtr & index_cache_heavy, // nullable + const ANNQueryInfoPtr & ann_query_info, + const ColumnDefinesPtr & col_defs, + const IColumnFileDataProviderPtr & data_provider, + const DMContext & dm_context, + const ReadTag & read_tag) +{ + RUNTIME_CHECK(ann_query_info != nullptr); + RUNTIME_CHECK(data_provider != nullptr); + RUNTIME_CHECK(!col_defs->empty()); + + std::optional vec_cd; + auto rest_columns = std::make_shared(); + rest_columns->reserve(col_defs->size() - 1); + for (const auto & cd : *col_defs) + { + if (cd.id == ann_query_info->column_id()) + vec_cd.emplace(cd); + else + rest_columns->emplace_back(cd); + } + RUNTIME_CHECK(vec_cd.has_value()); + RUNTIME_CHECK(rest_columns->size() + 1 == col_defs->size()); + + auto header = toEmptyBlock(*col_defs); + + return std::make_shared(VectorIndexStreamCtx{ + .index_cache_light = index_cache_light, + .index_cache_heavy = index_cache_heavy, + .ann_query_info = ann_query_info, + .col_defs = col_defs, + .vec_cd = *vec_cd, + .rest_col_defs = rest_columns, + .header = header, + .data_provider = data_provider, + .dm_context = &dm_context, + .read_tag = read_tag, + .tracing_id = dm_context.tracing_id, + .perf = VectorIndexPerf::create(), + }); +} + +VectorIndexStreamCtxPtr VectorIndexStreamCtx::createForStableOnlyTests( + const ANNQueryInfoPtr & ann_query_info, + const ColumnDefinesPtr & col_defs, + const 
LocalIndexCachePtr & index_cache_light) +{ + RUNTIME_CHECK(ann_query_info != nullptr); + RUNTIME_CHECK(!col_defs->empty()); + + std::optional vec_cd; + auto rest_columns = std::make_shared(); + rest_columns->reserve(col_defs->size() - 1); + for (const auto & cd : *col_defs) + { + if (cd.id == ann_query_info->column_id()) + vec_cd.emplace(cd); + else + rest_columns->emplace_back(cd); + } + RUNTIME_CHECK(vec_cd.has_value()); + RUNTIME_CHECK(rest_columns->size() + 1 == col_defs->size()); + + auto header = toEmptyBlock(*col_defs); + + return std::make_shared(VectorIndexStreamCtx{ + .index_cache_light = index_cache_light, + .index_cache_heavy = nullptr, + .ann_query_info = ann_query_info, + .col_defs = col_defs, + .vec_cd = *vec_cd, + .rest_col_defs = rest_columns, + .header = header, + .data_provider = nullptr, + .dm_context = nullptr, + .read_tag = ReadTag::Internal, + .tracing_id = "", + .perf = VectorIndexPerf::create(), + }); +} + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/Ctx.h b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/Ctx.h new file mode 100644 index 00000000000..9132caa0005 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/Ctx.h @@ -0,0 +1,82 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB::DM +{ + +/// Some commonly shared things through out a single vector search process. +/// Its expected lifetime >= VectorIndexInputStream. +struct VectorIndexStreamCtx +{ + const LocalIndexCachePtr index_cache_light; // nullable + const LocalIndexCachePtr index_cache_heavy; // nullable + const ANNQueryInfoPtr ann_query_info; + const ColumnDefinesPtr col_defs; + const ColumnDefine vec_cd; + const ColumnDefinesPtr rest_col_defs; + const Block header; + + // ============================================================ + // Fields below are for accessing ColumnFile + const IColumnFileDataProviderPtr data_provider; + + /// WARN: Do not use this in cases other than building a ColumnFileInputStream. + /// Because this field is not set in DMFile only tests. + /// FIXME: This pointer is also definitely dangerous. We cannot guarantee the lifetime. + const DMContext * dm_context; + + const ReadTag read_tag; + // ============================================================ + + /// Note: This field is also available in dm_context. However dm_context could be null in tests. 
+ const String tracing_id; + + // ============================================================ + // Fields below are mutable and shared without any lock, because our input streams + // will only operate one by one, and there is no contention in each read() + const VectorIndexPerfPtr perf; // perf is modifiable + std::vector vector_value; + /// reused in each read() + IColumn::Filter filter; + // ============================================================ + + + static VectorIndexStreamCtxPtr create( + const LocalIndexCachePtr & index_cache_light_, + const LocalIndexCachePtr & index_cache_heavy_, + const ANNQueryInfoPtr & ann_query_info_, + const ColumnDefinesPtr & col_defs_, + const IColumnFileDataProviderPtr & data_provider_, // Must provide for this interface + const DMContext & dm_context_, + const ReadTag & read_tag_); + + // Only used in tests! + static VectorIndexStreamCtxPtr createForStableOnlyTests( + const ANNQueryInfoPtr & ann_query_info_, + const ColumnDefinesPtr & col_defs_, + const LocalIndexCachePtr & index_cache_light_ = nullptr); +}; + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/Ctx_fwd.h b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/Ctx_fwd.h new file mode 100644 index 00000000000..6a0ee6b4eb9 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/Ctx_fwd.h @@ -0,0 +1,28 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#pragma once + +#include + +#include + +namespace DB::DM +{ + +struct VectorIndexStreamCtx; + +using VectorIndexStreamCtxPtr = std::shared_ptr; + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/DMFileInputStream.cpp b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/DMFileInputStream.cpp new file mode 100644 index 00000000000..2cc5dad99b7 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/DMFileInputStream.cpp @@ -0,0 +1,212 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include + + +namespace DB::DM +{ + +DMFileInputStreamProvideVectorIndex::DMFileInputStreamProvideVectorIndex( + const VectorIndexStreamCtxPtr & ctx_, + const DMFilePtr & dmfile_, + DMFileReader && rest_col_reader_) + : ctx(ctx_) + , dmfile(dmfile_) + , rest_col_reader(std::move(rest_col_reader_)) +{ + RUNTIME_CHECK(dmfile != nullptr); +} + +Block DMFileInputStreamProvideVectorIndex::read() +{ + // We expect setReturnRows() is called before doing any read(). 
+ RUNTIME_CHECK(sorted_results.owner != nullptr); + RUNTIME_CHECK(vec_index != nullptr); + + const auto sorted_results_view = sorted_results.view; + + if (rest_col_reader.read_block_infos.empty()) + return {}; + + const auto [start_pack_id, pack_count, rs_result, read_rows] = rest_col_reader.read_block_infos.front(); + const auto start_row_offset = rest_col_reader.pack_offset[start_pack_id]; + + auto begin = std::lower_bound( // + sorted_results_view.begin(), + sorted_results_view.end(), + start_row_offset, + [](const auto & lhs, const auto & rhs) { return lhs.rowid < rhs; }); + auto end = std::lower_bound( // + begin, + sorted_results_view.end(), + start_row_offset + read_rows, + [](const auto & lhs, const auto & rhs) { return lhs.rowid < rhs; }); + const std::span block_selected_rows{begin, end}; + if (block_selected_rows.empty()) + return {}; + + Stopwatch w(CLOCK_MONOTONIC_COARSE); + + auto vec_column = ctx->vec_cd.type->createColumn(); + vec_column->reserve(block_selected_rows.size()); + for (const auto & row : block_selected_rows) + { + vec_index->get(row.rowid, ctx->vector_value); + vec_column->insertData( + reinterpret_cast(ctx->vector_value.data()), + ctx->vector_value.size() * sizeof(Float32)); + } + + ctx->perf->n_dm_reads += 1; + ctx->perf->total_dm_read_vec_ms += w.elapsedMillisecondsFromLastTime(); + + Block block; + + // read other columns if needed + if (!rest_col_reader.read_columns.empty()) + { + ctx->filter.clear(); + ctx->filter.resize_fill(read_rows, 0); + for (const auto & row : block_selected_rows) + ctx->filter[row.rowid - start_row_offset] = 1; + + block = rest_col_reader.read(); + for (auto & col : block) + col.column = col.column->filter(ctx->filter, block_selected_rows.size()); + } + else + { + // Since we do not call `reader.read()` here, we need to pop the read_block_infos manually. 
+ rest_col_reader.read_block_infos.pop_front(); + } + + ctx->perf->total_dm_read_others_ms += w.elapsedMillisecondsFromLastTime(); + + const auto vec_col_pos = ctx->header.getPositionByName(ctx->vec_cd.name); + block.insert( + vec_col_pos, + ColumnWithTypeAndName{std::move(vec_column), ctx->vec_cd.type, ctx->vec_cd.name, ctx->vec_cd.id}); + block.setStartOffset(start_row_offset); + block.setRSResult(rs_result); + return block; +} + +inline VectorIndexReaderPtr DMFileInputStreamProvideVectorIndex::getVectorIndexReader() +{ + if (vec_index != nullptr) + return vec_index; + vec_index = VectorIndexReaderFromDMFile::load(ctx, dmfile); + return vec_index; +} + +void DMFileInputStreamProvideVectorIndex::setReturnRows(IProvideVectorIndex::SearchResultView sorted_results_) +{ + sorted_results = sorted_results_; + const auto sorted_results_view = sorted_results.view; + + // Vector index is very likely to filter out some packs. For example, + // if we query for Top 1, then only 1 pack will be remained. So we + // update the reader's read_block_infos to avoid reading unnecessary data for other columns. + + // The following logic is nearly the same with DMFileReader::initReadBlockInfos. 
+ + auto & read_block_infos = rest_col_reader.read_block_infos; + const auto & pack_offset = rest_col_reader.pack_offset; + + read_block_infos.clear(); + const auto & pack_stats = dmfile->getPackStats(); + const auto & pack_res = rest_col_reader.pack_filter->getPackRes(); + + // Update valid_packs_before_search + { + ctx->perf->n_dm_searches += 1; + ctx->perf->dm_packs_in_file += pack_stats.size(); + for (const auto res : pack_res) + ctx->perf->dm_packs_before_search += res.isUse(); + } + + // Update read_block_infos + size_t start_pack_id = 0; + size_t read_rows = 0; + auto prev_block_pack_res = RSResult::All; + auto sorted_results_it = sorted_results_view.begin(); + size_t pack_id = 0; + for (; pack_id < pack_stats.size(); ++pack_id) + { + if (sorted_results_it == sorted_results_view.end()) + break; + const auto begin = std::lower_bound( // + sorted_results_it, + sorted_results_view.end(), + pack_offset[pack_id], + [](const auto & lhs, const auto & rhs) { return lhs.rowid < rhs; }); + const auto end = std::lower_bound( // + begin, + sorted_results_view.end(), + pack_offset[pack_id] + pack_stats[pack_id].rows, + [](const auto & lhs, const auto & rhs) { return lhs.rowid < rhs; }); + bool is_use = begin != end; + bool reach_limit = read_rows >= rest_col_reader.rows_threshold_per_read; + bool break_all_match = prev_block_pack_res.allMatch() && !pack_res[pack_id].allMatch() + && read_rows >= rest_col_reader.rows_threshold_per_read / 2; + + if (!is_use) + { + if (read_rows > 0) + read_block_infos.emplace_back(start_pack_id, pack_id - start_pack_id, prev_block_pack_res, read_rows); + start_pack_id = pack_id + 1; + read_rows = 0; + prev_block_pack_res = RSResult::All; + } + else if (reach_limit || break_all_match) + { + if (read_rows > 0) + read_block_infos.emplace_back(start_pack_id, pack_id - start_pack_id, prev_block_pack_res, read_rows); + start_pack_id = pack_id; + read_rows = pack_stats[pack_id].rows; + prev_block_pack_res = pack_res[pack_id]; + } + else + { + 
prev_block_pack_res = prev_block_pack_res && pack_res[pack_id]; + read_rows += pack_stats[pack_id].rows; + } + + sorted_results_it = end; + } + if (read_rows > 0) + read_block_infos.emplace_back(start_pack_id, pack_id - start_pack_id, prev_block_pack_res, read_rows); + + // Update valid_packs_after_search + { + for (const auto & block_info : read_block_infos) + ctx->perf->dm_packs_after_search += block_info.pack_count; + } + + RUNTIME_CHECK_MSG(sorted_results_it == sorted_results_view.end(), "All results are not consumed"); +} + +inline Block DMFileInputStreamProvideVectorIndex::getHeader() const +{ + return ctx->header; +} + + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/DMFileInputStream.h b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/DMFileInputStream.h new file mode 100644 index 00000000000..f039b9d0675 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/DMFileInputStream.h @@ -0,0 +1,88 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + + +namespace DB::DM +{ + +/** + * @brief DMFileInputStreamProvideVectorIndex is similar to DMFileBlockInputStream. + * However it can read data efficiently with the help of vector index. + * + * General steps: + * 1. Read all PK, Version and Del Marks (respecting Pack filters). + * 2. Construct a bitmap of valid rows (in memory). 
This bitmap guides the reading of vector index to determine whether a row is valid or not. + * 3. Perform a vector search for Top K vector rows. We now have K row_ids whose vector distance is close. + * 4. Map these row_ids to packids as the new pack filter. + * 5. Read from other columns with the new pack filter. + * For each read, join other columns and vector column together. + * + * Step 3~4 is performed lazily at first read. + * + * Before constructing this class, the caller must ensure that vector index + * exists on the corresponding column. If the index does not exist, the caller + * should use the standard DMFileBlockInputStream. + */ +class DMFileInputStreamProvideVectorIndex + : public IProvideVectorIndex + , public NopSkippableBlockInputStream +{ +public: + static auto create(const VectorIndexStreamCtxPtr & ctx, const DMFilePtr & dmfile, DMFileReader && rest_col_reader) + { + return std::make_shared(ctx, dmfile, std::move(rest_col_reader)); + } + + explicit DMFileInputStreamProvideVectorIndex( + const VectorIndexStreamCtxPtr & ctx_, + const DMFilePtr & dmfile_, + DMFileReader && rest_col_reader_); + +public: // Implements IProvideVectorIndex + VectorIndexReaderPtr getVectorIndexReader() override; + + void setReturnRows(IProvideVectorIndex::SearchResultView sorted_results) override; + +public: // Implements IBlockInputStream + Block read() override; + + String getName() const override { return "VectorIndexDMFile"; } + + Block getHeader() const override; + +private: + // Update the read_block_infos according to the sorted_results. 
+ void updateReadBlockInfos(); + +private: + const VectorIndexStreamCtxPtr ctx; + const DMFilePtr dmfile; + VectorIndexReaderPtr vec_index = nullptr; + // Vector column should be excluded in the reader + DMFileReader rest_col_reader; + + /// Set after calling setReturnRows + IProvideVectorIndex::SearchResultView sorted_results; +}; + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/DMFileInputStream_fwd.h b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/DMFileInputStream_fwd.h new file mode 100644 index 00000000000..f4109e653d1 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/DMFileInputStream_fwd.h @@ -0,0 +1,26 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace DB::DM +{ + +class DMFileInputStreamProvideVectorIndex; + +using DMFileInputStreamProvideVectorIndexPtr = std::shared_ptr; + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/IProvideVectorIndex.h b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/IProvideVectorIndex.h new file mode 100644 index 00000000000..243fca7581f --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/IProvideVectorIndex.h @@ -0,0 +1,60 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include +#include +#include + +namespace DB::DM +{ + +/** + * @brief If some InputStream is capable of providing and filtering via a VectorIndex, + * then it should inherit and implement this class. + * InputStream could then only return a subset of the matched rows, thanks to joining multiple TopK results + * via VectorIndexInputStream. + */ +class IProvideVectorIndex +{ +public: + struct SearchResult + { + UInt32 rowid{}; // Always local + Float32 distance{}; + }; + + struct SearchResultView + { + std::shared_ptr> owner = nullptr; + std::span view = {}; + }; + +public: + virtual ~IProvideVectorIndex() = default; + + /// Returns a VectorIndexReader from the current BlockInputStream. + virtual VectorIndexReaderPtr getVectorIndexReader() = 0; + + /// This inputStream must only return these rows as the final result. + /// This is always called before the first read(). + /// `sorted_results` is ensured to be sorted and does not contain duplicates. + virtual void setReturnRows(SearchResultView sorted_results) = 0; +}; + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/InputStream.cpp b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/InputStream.cpp new file mode 100644 index 00000000000..3653f45bbe6 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/InputStream.cpp @@ -0,0 +1,230 @@ +// Copyright 2025 PingCAP, Inc.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB::DM +{ + +void VectorIndexInputStream::initSearchResults() +{ + if (searchResultsInited) + return; + + UInt32 precedes_rows = 0; + auto search_results = std::make_shared>(); + search_results->reserve(ctx->ann_query_info->top_k()); + + // 1. Do vector search for all index streams. + for (size_t i = 0, i_max = stream->children.size(); i < i_max; ++i) + { + if (auto * index_stream = index_streams[i]; index_stream) + { + auto reader = index_stream->getVectorIndexReader(); + RUNTIME_CHECK(reader != nullptr); + auto current_filter = BitmapFilterView(bitmap_filter, precedes_rows, stream->rows[i]); + auto results = reader->search(ctx->ann_query_info, current_filter); + const size_t results_n = results.size(); + VectorIndexReader::Key last_rowid = std::numeric_limits::max(); + for (size_t i = 0; i < results_n; ++i) + { + const auto rowid = results[i].member.key; + if (rowid == last_rowid) + continue; // Perform a very simple deduplicate + last_rowid = rowid; + // The result from usearch may contain filtered out rows so we filter again + if (current_filter[rowid]) + search_results->emplace_back(IProvideVectorIndex::SearchResult{ + // We need to sort globally so convert it to a global offset temporarily. + // We will convert it back to local offset when we feed it back to substreams. 
+ .rowid = rowid + precedes_rows, + .distance = results[i].distance, + }); + } + } + precedes_rows += stream->rows[i]; + } + + // 2. Keep the top k minimum distances rows. + // [0, top_k) will be the top k minimum distances rows. (However it is not sorted) + const auto top_k = std::min(search_results->size(), ctx->ann_query_info->top_k()); + std::nth_element( // + search_results->begin(), + search_results->begin() + top_k, + search_results->end(), + [](const auto & lhs, const auto & rhs) { return lhs.distance < rhs.distance; }); + search_results->resize(top_k); + + // 3. Sort by rowid for the first K rows. + std::sort( // + search_results->begin(), + search_results->end(), + [](const auto & lhs, const auto & rhs) { return lhs.rowid < rhs.rowid; }); + + // 4. Finally, notify all index streams to only return these rows. + precedes_rows = 0; + auto sr_it = search_results->begin(); + for (size_t i = 0, i_max = stream->children.size(); i < i_max; ++i) + { + if (auto * index_stream = index_streams[i]; index_stream) + { + auto reader = index_stream->getVectorIndexReader(); + RUNTIME_CHECK(reader != nullptr); + auto begin = std::lower_bound( // + sr_it, + search_results->end(), + precedes_rows, + [](const auto & lhs, const auto & rhs) { return lhs.rowid < rhs; }); + auto end = std::lower_bound( // + begin, + search_results->end(), + precedes_rows + stream->rows[i], + [](const auto & lhs, const auto & rhs) { return lhs.rowid < rhs; }); + // Convert back to local offset. 
+ for (auto it = begin; it != end; ++it) + it->rowid -= precedes_rows; + index_stream->setReturnRows(IProvideVectorIndex::SearchResultView{ + .owner = search_results, + .view = std::span{ + &*begin, + static_cast(std::distance(begin, end))}}); + sr_it = end; + } + precedes_rows += stream->rows[i]; + } + + if (ctx->dm_context != nullptr && ctx->dm_context->scan_context != nullptr) + { + auto scan_context = ctx->dm_context->scan_context; + scan_context->total_vector_idx_load_from_s3 += ctx->perf->load_from_stable_s3; + scan_context->total_vector_idx_load_from_disk + += ctx->perf->load_from_stable_disk + ctx->perf->load_from_column_file; + scan_context->total_vector_idx_load_from_cache += ctx->perf->load_from_cache; + scan_context->total_vector_idx_load_time_ms += ctx->perf->total_load_ms; + scan_context->total_vector_idx_search_time_ms += ctx->perf->total_search_ms; + scan_context->total_vector_idx_search_visited_nodes += ctx->perf->visited_nodes; + scan_context->total_vector_idx_search_discarded_nodes += ctx->perf->discarded_nodes; + } + + searchResultsInited = true; +} + +VectorIndexInputStream::~VectorIndexInputStream() +{ + LOG_DEBUG( + log, + "Vector search reading finished, " + "load_index={:.3f}s (from:[cf/dmf]={}/{} noindex:[cf/dmf]={}/{} [cached/cf_data/dmf_disk/dmf_s3]={}/{}/{}/{}), " + "vec_search={:.3f}s, " + "vec_get_[cf/dmf]={:.3f}s/{:.3f}s, " + "other_get_[cf/dmf]={:.3f}s/{:.3f}s, " + "pack_[before/after]={}/{}, " + "top_k_[query/visited/discarded/result]={}/{}/{}/{}", + static_cast(ctx->perf->total_load_ms) / 1000.0, + ctx->perf->n_from_cf_index, + ctx->perf->n_from_dmf_index, + ctx->perf->n_from_cf_noindex, + ctx->perf->n_from_dmf_noindex, + ctx->perf->load_from_cache, + ctx->perf->load_from_column_file, + ctx->perf->load_from_stable_disk, + ctx->perf->load_from_stable_s3, + static_cast(ctx->perf->total_search_ms) / 1000.0, + static_cast(ctx->perf->total_cf_read_vec_ms) / 1000.0, + static_cast(ctx->perf->total_dm_read_vec_ms) / 1000.0, + 
static_cast(ctx->perf->total_cf_read_others_ms) / 1000.0, + static_cast(ctx->perf->total_dm_read_others_ms) / 1000.0, + ctx->perf->dm_packs_before_search, + ctx->perf->dm_packs_after_search, + ctx->ann_query_info->top_k(), + ctx->perf->visited_nodes, + ctx->perf->discarded_nodes, + ctx->perf->returned_nodes); +} + +Block VectorIndexInputStream::read() +{ + initSearchResults(); + + auto block = stream->read(); + if (!block) + { + onReadFinished(); + return {}; + } + + // The block read from `IProvideVectorIndex` will only return the selected rows after global TopK. Return it directly. + // MVCC has already been done when searching the vector index. + // For streams which are not `IProvideVectorIndex`, the block should be filtered by MVCC bitmap. + if (auto idx = std::distance(stream->children.begin(), stream->current_stream); !index_streams[idx]) + { + ctx->filter.resize(block.rows()); + if (bool all_match = bitmap_filter->get(ctx->filter, block.startOffset(), block.rows()); all_match) + return block; + + size_t passed_count = countBytesInFilter(ctx->filter); + for (auto & col : block) + col.column = col.column->filter(ctx->filter, passed_count); + } + + return block; +} + +void VectorIndexInputStream::onReadFinished() +{ + if (isReadFinished) + return; + isReadFinished = true; + + // For some reason it is too late if we report this in the destructor. 
+ if (ctx->dm_context != nullptr && ctx->dm_context->scan_context != nullptr) + { + auto scan_context = ctx->dm_context->scan_context; + scan_context->total_vector_idx_read_vec_time_ms + += ctx->perf->total_cf_read_vec_ms + ctx->perf->total_dm_read_vec_ms; + scan_context->total_vector_idx_read_others_time_ms + += ctx->perf->total_cf_read_others_ms + ctx->perf->total_dm_read_others_ms; + } +} + +VectorIndexInputStream::VectorIndexInputStream( + const VectorIndexStreamCtxPtr & ctx_, + const BitmapFilterPtr & bitmap_filter_, + std::shared_ptr> stream_) + : ctx(ctx_) + , bitmap_filter(bitmap_filter_) + , stream(stream_) + , log(Logger::get(ctx->tracing_id)) +{ + RUNTIME_CHECK(bitmap_filter != nullptr); + RUNTIME_CHECK(stream != nullptr); + + index_streams.reserve(stream->children.size()); + for (const auto & sub_stream : stream->children) + { + if (auto * index_stream = dynamic_cast(sub_stream.get()); index_stream) + index_streams.push_back(index_stream); + else + index_streams.push_back(nullptr); + } +} + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/InputStream.h b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/InputStream.h new file mode 100644 index 00000000000..e0d25a11cf4 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/InputStream.h @@ -0,0 +1,78 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include + +namespace DB::DM +{ + +class BitmapFilter; +using BitmapFilterPtr = std::shared_ptr; + +/** + * @brief Unifies multiple sub-streams that support reading the vector index. + * It first performs a vector search over index for all substreams. + * Then, a final TopK result is formed. + * After that, it reads the actual data with only selected TopK rows from these substreams. + * This class ensures the output result always respects the BitmapFilter. + */ +class VectorIndexInputStream : public NopSkippableBlockInputStream +{ +public: + static auto create( + const VectorIndexStreamCtxPtr & ctx_, + const BitmapFilterPtr & bitmap_filter_, + std::shared_ptr> stream_) + { + return std::make_shared(ctx_, bitmap_filter_, stream_); + } + + VectorIndexInputStream( + const VectorIndexStreamCtxPtr & ctx_, + const BitmapFilterPtr & bitmap_filter_, + std::shared_ptr> stream); + + ~VectorIndexInputStream() override; + +public: // Implements IBlockInputStream + String getName() const override { return "VectorIndexConcat"; } + + Block getHeader() const override { return stream->getHeader(); } + + Block read() override; + +private: + void onReadFinished(); + bool isReadFinished = false; + +private: + const VectorIndexStreamCtxPtr ctx; + const BitmapFilterPtr bitmap_filter; + const std::shared_ptr> stream; + // Assigned in the constructor. Pointers to stream's children, nullptr if the child is not a VectorIndexBlockInputStream. + std::vector index_streams; + + const LoggerPtr log; + + /// Before returning any actual data, we first perform a vector search over index for substreams.
+ void initSearchResults(); + bool searchResultsInited = false; +}; + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/InputStream_fwd.h b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/InputStream_fwd.h new file mode 100644 index 00000000000..00ec6e9a968 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/InputStream_fwd.h @@ -0,0 +1,26 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace DB::DM +{ + +class VectorIndexInputStream; + +using VectorIndexInputStreamPtr = std::shared_ptr; + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/ReaderFromColumnFileTiny.cpp b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/ReaderFromColumnFileTiny.cpp new file mode 100644 index 00000000000..c41ac501f1d --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/ReaderFromColumnFileTiny.cpp @@ -0,0 +1,85 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB::DM +{ + +VectorIndexReaderPtr VectorIndexReaderFromColumnFileTiny::load( + const VectorIndexStreamCtxPtr & ctx, + const ColumnFileTiny & tiny_file) +{ + Stopwatch w(CLOCK_MONOTONIC_COARSE); + + auto const * index_info = tiny_file.findIndexInfo(ctx->ann_query_info->index_id()); + RUNTIME_CHECK(index_info != nullptr); + RUNTIME_CHECK(index_info->index_props().kind() == dtpb::IndexFileKind::VECTOR_INDEX); + RUNTIME_CHECK(index_info->index_props().has_vector_index()); + + auto index_page_id = index_info->index_page_id(); + + bool is_load_from_storage = false; + auto load_from_page_storage = [&]() { + is_load_from_storage = true; + std::vector index_fields = {0}; + auto index_page = ctx->data_provider->readTinyData(index_page_id, index_fields); + ReadBufferFromOwnString read_buf(index_page.data); + CompressedReadBuffer compressed(read_buf); + return VectorIndexReader::createFromMemory(index_info->index_props().vector_index(), ctx->perf, compressed); + }; + + VectorIndexReaderPtr vec_index = nullptr; + // ColumnFile vector index stores all data in memory, can not be evicted by system, so use heavy cache. 
+ if (ctx->index_cache_heavy) + { + const auto key = fmt::format("{}{}", LocalIndexCache::COLUMNFILETINY_INDEX_NAME_PREFIX, index_page_id); + auto local_index = ctx->index_cache_heavy->getOrSet(key, load_from_page_storage); + vec_index = std::dynamic_pointer_cast(local_index); + } + else + vec_index = load_from_page_storage(); + + RUNTIME_CHECK(vec_index != nullptr); + + { // Statistics + double elapsed = w.elapsedSeconds(); + if (is_load_from_storage) + { + ctx->perf->load_from_column_file += 1; + GET_METRIC(tiflash_vector_index_duration, type_load_cf).Observe(elapsed); + } + else + { + ctx->perf->load_from_cache += 1; + GET_METRIC(tiflash_vector_index_duration, type_load_cache).Observe(elapsed); + } + ctx->perf->total_load_ms += elapsed * 1000; + } + + return vec_index; +} + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/ReaderFromColumnFileTiny.h b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/ReaderFromColumnFileTiny.h new file mode 100644 index 00000000000..051d37b5a63 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/ReaderFromColumnFileTiny.h @@ -0,0 +1,31 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include + +namespace DB::DM +{ + +class ColumnFileTiny; + +class VectorIndexReaderFromColumnFileTiny +{ +public: + static VectorIndexReaderPtr load(const VectorIndexStreamCtxPtr & ctx, const ColumnFileTiny & tiny_file); +}; + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/ReaderFromDMFile.cpp b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/ReaderFromDMFile.cpp new file mode 100644 index 00000000000..6ce567b1828 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/ReaderFromDMFile.cpp @@ -0,0 +1,150 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB::ErrorCodes +{ +extern const int S3_ERROR; +} // namespace DB::ErrorCodes + +namespace DB::DM +{ + +VectorIndexReaderPtr VectorIndexReaderFromDMFile::load(const VectorIndexStreamCtxPtr & ctx, const DMFilePtr & dmfile) +{ + Stopwatch w(CLOCK_MONOTONIC_COARSE); + + const auto col_id = ctx->ann_query_info->column_id(); + const auto index_id = ctx->ann_query_info->index_id(); + + RUNTIME_CHECK(dmfile->useMetaV2()); // v3 + + // Check vector index exists on the column + auto vector_index = dmfile->getLocalIndex(col_id, index_id); + RUNTIME_CHECK(vector_index.has_value(), col_id, index_id); + RUNTIME_CHECK(vector_index->index_props().kind() == dtpb::IndexFileKind::VECTOR_INDEX); + RUNTIME_CHECK(vector_index->index_props().has_vector_index()); + + bool has_s3_download = false; + bool has_load_from_file = false; + + // If local file is invalidated, cache is not valid anymore. So we + // need to ensure file exists on local fs first. + const auto index_file_path = index_id > 0 // + ? dmfile->localIndexPath(index_id, TiDB::ColumnarIndexKind::Vector) // + : dmfile->colIndexPath(DMFile::getFileNameBase(col_id)); + String local_index_file_path; + if (auto s3_file_name = S3::S3FilenameView::fromKeyWithPrefix(index_file_path); s3_file_name.isValid()) + { + // Disaggregated mode + auto * file_cache = FileCache::instance(); + RUNTIME_CHECK_MSG(file_cache, "Must enable S3 file cache to use vector index"); + + auto perf_begin = PerfContext::file_cache; + + // If download file failed, retry a few times. 
+ for (auto i = 3; i > 0; --i) + { + try + { + if (auto file_guard = file_cache->downloadFileForLocalRead( // + s3_file_name, + vector_index->index_props().file_size()); + file_guard) + { + local_index_file_path = file_guard->getLocalFileName(); + break; // Successfully downloaded index into local cache + } + + throw Exception(ErrorCodes::S3_ERROR, "Failed to download vector index file {}", index_file_path); + } + catch (...) + { + if (i <= 1) + throw; + } + } + + if ( // + PerfContext::file_cache.fg_download_from_s3 > perf_begin.fg_download_from_s3 || // + PerfContext::file_cache.fg_wait_download_from_s3 > perf_begin.fg_wait_download_from_s3) + has_s3_download = true; + } + else + { + // Not disaggregated mode + local_index_file_path = index_file_path; + } + + auto load_from_file = [&]() { + has_load_from_file = true; + return VectorIndexReader::createFromMmap( + vector_index->index_props().vector_index(), + ctx->perf, + local_index_file_path); + }; + + VectorIndexReaderPtr vec_index = nullptr; + // DMFile vector index uses mmap to read data, does not directly occupy memory, so use the light cache. + if (ctx->index_cache_light) + { + // Note: must use local_index_file_path as the cache key, because cache + // will check whether file is still valid and try to remove memory references + // when file is dropped. + auto local_index = ctx->index_cache_light->getOrSet(local_index_file_path, load_from_file); + vec_index = std::dynamic_pointer_cast(local_index); + } + else + vec_index = load_from_file(); + + RUNTIME_CHECK(vec_index != nullptr); + + { // Statistics + double elapsed = w.elapsedSeconds(); + if (has_s3_download) + { + // it could be possible that s3=true but load_from_file=false, it means we download a file + // and then reuse the memory cache. The majority time comes from s3 download + // so we still count it as s3 download. 
+ ctx->perf->load_from_stable_s3 += 1; + GET_METRIC(tiflash_vector_index_duration, type_load_dmfile_s3).Observe(elapsed); + } + else if (has_load_from_file) + { + ctx->perf->load_from_stable_disk += 1; + GET_METRIC(tiflash_vector_index_duration, type_load_dmfile_local).Observe(elapsed); + } + else + { + ctx->perf->load_from_cache += 1; + GET_METRIC(tiflash_vector_index_duration, type_load_cache).Observe(elapsed); + } + ctx->perf->total_load_ms += elapsed * 1000; + } + + return vec_index; +} + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/ReaderFromDMFile.h b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/ReaderFromDMFile.h new file mode 100644 index 00000000000..67df3c20835 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Stream/ReaderFromDMFile.h @@ -0,0 +1,30 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include + +namespace DB::DM +{ + +class VectorIndexReaderFromDMFile +{ +public: + static VectorIndexReaderPtr load(const VectorIndexStreamCtxPtr & ctx, const DMFilePtr & dmfile); +}; + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Writer.cpp b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Writer.cpp index 68ad7491258..685d69f66c1 100644 --- a/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Writer.cpp +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Writer.cpp @@ -20,9 +20,8 @@ #include #include #include -#include #include -#include + namespace DB::ErrorCodes { @@ -36,13 +35,16 @@ namespace DB::DM VectorIndexWriterInternal::VectorIndexWriterInternal(const TiDB::VectorIndexDefinitionPtr & definition_) : definition(definition_) - , index(USearchImplType::make(unum::usearch::metric_punned_t( // - definition_->dimension, - getUSearchMetricKind(definition->distance_metric)))) { - RUNTIME_CHECK(definition_->kind == tipb::VectorIndexKind::HNSW); + RUNTIME_CHECK(definition != nullptr); + RUNTIME_CHECK(definition->kind == tipb::VectorIndexKind::HNSW); RUNTIME_CHECK(definition->dimension > 0); RUNTIME_CHECK(definition->dimension <= TiDB::MAX_VECTOR_DIMENSION); + + index = USearchImplType::make(unum::usearch::metric_punned_t( // + definition->dimension, + getUSearchMetricKind(definition->distance_metric))); + GET_METRIC(tiflash_vector_index_active_instances, type_build).Increment(); } @@ -65,7 +67,7 @@ void VectorIndexWriterInternal::addBlock( index.reserve(unum::usearch::ceil2(index.size() + column.size())); - Stopwatch w; + Stopwatch w(CLOCK_MONOTONIC_COARSE); SCOPE_EXIT({ total_duration += w.elapsedSeconds(); }); Stopwatch w_proceed_check(CLOCK_MONOTONIC_COARSE); @@ -112,33 +114,42 @@ void VectorIndexWriterInternal::addBlock( last_reported_memory_usage = current_memory_usage; } -void VectorIndexWriterInternal::saveToFile(std::string_view path) const 
+VectorIndexWriterInternal::~VectorIndexWriterInternal() { - Stopwatch w; - SCOPE_EXIT({ total_duration += w.elapsedSeconds(); }); + GET_METRIC(tiflash_vector_index_duration, type_build).Observe(total_duration); + GET_METRIC(tiflash_vector_index_memory_usage, type_build) + .Decrement(static_cast(last_reported_memory_usage)); + GET_METRIC(tiflash_vector_index_active_instances, type_build).Decrement(); +} - auto result = index.save(unum::usearch::output_file_t(path.data())); - RUNTIME_CHECK_MSG(result, "Failed to save vector index: {} path={}", result.error.what(), path); +void VectorIndexWriterInternal::saveFileProps(dtpb::IndexFilePropsV2 * pb_idx) const +{ + auto * pb_vec_idx = pb_idx->mutable_vector_index(); + pb_vec_idx->set_format_version(0); + pb_vec_idx->set_dimensions(definition->dimension); + pb_vec_idx->set_distance_metric(tipb::VectorDistanceMetric_Name(definition->distance_metric)); } -void VectorIndexWriterInternal::saveToBuffer(WriteBuffer & write_buf) const +void VectorIndexWriterOnDisk::saveToFile() const { - Stopwatch w; - SCOPE_EXIT({ total_duration += w.elapsedSeconds(); }); + Stopwatch w(CLOCK_MONOTONIC_COARSE); + SCOPE_EXIT({ writer.total_duration += w.elapsedSeconds(); }); - auto result = index.save_to_stream([&](void const * buffer, std::size_t length) { + auto result = writer.index.save(unum::usearch::output_file_t(index_file.data())); + RUNTIME_CHECK_MSG(result, "Failed to save vector index: {} path={}", result.error.what(), index_file); +} + +void VectorIndexWriterInMemory::saveToBuffer(WriteBuffer & write_buf) +{ + Stopwatch w(CLOCK_MONOTONIC_COARSE); + SCOPE_EXIT({ writer.total_duration += w.elapsedSeconds(); }); + + auto result = writer.index.save_to_stream([&](void const * buffer, std::size_t length) { write_buf.write(reinterpret_cast(buffer), length); return true; }); + write_buf.next(); RUNTIME_CHECK_MSG(result, "Failed to save vector index: {}", result.error.what()); } -VectorIndexWriterInternal::~VectorIndexWriterInternal() -{ - 
GET_METRIC(tiflash_vector_index_duration, type_build).Observe(total_duration); - GET_METRIC(tiflash_vector_index_memory_usage, type_build) - .Decrement(static_cast(last_reported_memory_usage)); - GET_METRIC(tiflash_vector_index_active_instances, type_build).Decrement(); -} - } // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Writer.h b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Writer.h index 22839415e27..715706c04f4 100644 --- a/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Writer.h +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndex/Writer.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -26,69 +27,82 @@ namespace DB::DM class VectorIndexWriterInternal { + friend class VectorIndexWriterInMemory; + friend class VectorIndexWriterOnDisk; + protected: using USearchImplType = unum::usearch::index_dense_gt; - void saveToFile(std::string_view path) const; - void saveToBuffer(WriteBuffer & write_buf) const; - public: /// The key is the row's offset in the DMFile. using Key = UInt32; - using ProceedCheckFn = std::function; explicit VectorIndexWriterInternal(const TiDB::VectorIndexDefinitionPtr & definition_); - virtual ~VectorIndexWriterInternal(); + ~VectorIndexWriterInternal(); + using ProceedCheckFn = LocalIndexWriter::ProceedCheckFn; void addBlock(const IColumn & column, const ColumnVector * del_mark, ProceedCheckFn should_proceed); + void saveFileProps(dtpb::IndexFilePropsV2 * pb_idx) const; public: const TiDB::VectorIndexDefinitionPtr definition; private: - USearchImplType index; UInt64 added_rows = 0; // Includes nulls and deletes. Used as the index key. 
- - mutable double total_duration = 0; size_t last_reported_memory_usage = 0; + USearchImplType index; + mutable double total_duration = 0; }; -class VectorIndexWriterInMemory : public VectorIndexWriterInternal +class VectorIndexWriterInMemory : public LocalIndexWriterInMemory { public: - explicit VectorIndexWriterInMemory(const TiDB::VectorIndexDefinitionPtr & definition_) - : VectorIndexWriterInternal(definition_) + explicit VectorIndexWriterInMemory(IndexID index_id, const TiDB::VectorIndexDefinitionPtr & definition) + : LocalIndexWriterInMemory(index_id) + , writer(definition) {} - static VectorIndexWriterInMemoryPtr create(const TiDB::VectorIndexDefinitionPtr & definition) + void saveToBuffer(WriteBuffer & write_buf) override; + + void addBlock(const IColumn & column, const ColumnVector * del_mark, ProceedCheckFn should_proceed) override { - return std::make_shared(definition); + writer.addBlock(column, del_mark, should_proceed); } - void finalize(WriteBuffer & write_buf) const { saveToBuffer(write_buf); } + void saveFileProps(dtpb::IndexFilePropsV2 * pb_idx) const override { writer.saveFileProps(pb_idx); } + + dtpb::IndexFileKind kind() const override { return dtpb::IndexFileKind::VECTOR_INDEX; } + +private: + VectorIndexWriterInternal writer; }; // FIXME: The building process is still in-memory. Only the index file is saved to disk. 
-class VectorIndexWriterOnDisk : public VectorIndexWriterInternal +class VectorIndexWriterOnDisk : public LocalIndexWriterOnDisk { public: - explicit VectorIndexWriterOnDisk(std::string_view index_file, const TiDB::VectorIndexDefinitionPtr & definition_) - : VectorIndexWriterInternal(definition_) - , index_file(index_file) - {} - - static VectorIndexWriterOnDiskPtr create( + explicit VectorIndexWriterOnDisk( + IndexID index_id, std::string_view index_file, const TiDB::VectorIndexDefinitionPtr & definition) + : LocalIndexWriterOnDisk(index_id, index_file) + , writer(definition) + {} + + void saveToFile() const override; + + void addBlock(const IColumn & column, const ColumnVector * del_mark, ProceedCheckFn should_proceed) override { - return std::make_shared(index_file, definition); + writer.addBlock(column, del_mark, should_proceed); } - void finalize() const { saveToFile(index_file); } + void saveFileProps(dtpb::IndexFilePropsV2 * pb_idx) const override { writer.saveFileProps(pb_idx); } + + dtpb::IndexFileKind kind() const override { return dtpb::IndexFileKind::VECTOR_INDEX; } private: - const std::string index_file; + VectorIndexWriterInternal writer; }; } // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Segment.cpp b/dbms/src/Storages/DeltaMerge/Segment.cpp index cf88e802dad..569ccdae79d 100644 --- a/dbms/src/Storages/DeltaMerge/Segment.cpp +++ b/dbms/src/Storages/DeltaMerge/Segment.cpp @@ -26,8 +26,6 @@ #include #include #include -#include -#include #include #include #include @@ -41,6 +39,9 @@ #include #include #include +#include +#include +#include #include #include #include @@ -779,26 +780,19 @@ bool Segment::isDefinitelyEmpty(DMContext & dm_context, const SegmentSnapshotPtr } // The delta stream is empty. Let's then try to read from stable. 
- { - SkippableBlockInputStreams streams; - for (const auto & file : segment_snap->stable->getDMFiles()) - { - DMFileBlockInputStreamBuilder builder(dm_context.global_context); - auto stream = builder - .setRowsThreshold( - std::numeric_limits::max()) // TODO: May be we could have some better settings - .onlyReadOnePackEveryTime() - .build(file, *columns_to_read, read_ranges, dm_context.scan_context); - streams.push_back(stream); - } - - BlockInputStreamPtr stable_stream - = std::make_shared>(streams, dm_context.scan_context); - stable_stream = std::make_shared>(stable_stream, read_ranges, 0); - stable_stream->readPrefix(); + for (const auto & file : segment_snap->stable->getDMFiles()) + { + DMFileBlockInputStreamBuilder builder(dm_context.global_context); + auto stream = builder + .setRowsThreshold( + std::numeric_limits::max()) // TODO: May be we could have some better settings + .onlyReadOnePackEveryTime() + .build(file, *columns_to_read, read_ranges, dm_context.scan_context); + auto stream2 = std::make_shared>(stream, read_ranges, 0); + stream2->readPrefix(); while (true) { - Block block = stable_stream->read(); + Block block = stream2->read(); if (!block) break; if (block.rows() > 0) @@ -806,7 +800,7 @@ bool Segment::isDefinitelyEmpty(DMContext & dm_context, const SegmentSnapshotPtr // because we are not considering the delete range. 
return false; } - stable_stream->readSuffix(); + stream2->readSuffix(); } // We cannot read out anything from the delta stream and the stable stream, @@ -3242,7 +3236,7 @@ BitmapFilterPtr Segment::buildBitmapFilterStableOnly( getVersionColumnDefine(), getTagColumnDefine(), }; - BlockInputStreamPtr stream = segment_snap->stable->getInputStream( + BlockInputStreamPtr stream = segment_snap->stable->getInputStream( dm_context, columns_to_read, read_ranges, @@ -3253,8 +3247,7 @@ BitmapFilterPtr Segment::buildBitmapFilterStableOnly( new_pack_filter_results, /*is_fast_scan*/ false, /*enable_del_clean_read*/ false, - /*read_packs*/ {}, - /*need_row_id*/ true); + /*read_packs*/ {}); stream = std::make_shared>(stream, read_ranges, 0); const ColumnDefines read_columns{ getExtraHandleColumnDefine(is_common_handle), @@ -3287,13 +3280,12 @@ SkippableBlockInputStreamPtr Segment::getConcatSkippableBlockInputStream( size_t expected_block_size, ReadTag read_tag) { - static constexpr bool NeedRowID = false; // set `is_fast_scan` to true to try to enable clean read auto enable_handle_clean_read = !hasColumn(columns_to_read, MutSup::extra_handle_id); constexpr auto is_fast_scan = true; auto enable_del_clean_read = !hasColumn(columns_to_read, MutSup::version_col_id); - SkippableBlockInputStreamPtr stable_stream = segment_snap->stable->getInputStream( + auto stream = segment_snap->stable->getInputStream( dm_context, columns_to_read, read_ranges, @@ -3304,8 +3296,7 @@ SkippableBlockInputStreamPtr Segment::getConcatSkippableBlockInputStream( pack_filter_results, is_fast_scan, enable_del_clean_read, - /* read_packs */ {}, - NeedRowID); + /* read_packs */ {}); auto columns_to_read_ptr = std::make_shared(columns_to_read); @@ -3324,14 +3315,12 @@ SkippableBlockInputStreamPtr Segment::getConcatSkippableBlockInputStream( this->rowkey_range, read_tag); - auto stream = std::dynamic_pointer_cast>(stable_stream); - assert(stream != nullptr); stream->appendChild(persisted_files_stream, 
persisted_files->getRows()); stream->appendChild(mem_table_stream, memtable->getRows()); return stream; } -std::tuple Segment::getConcatVectorIndexBlockInputStream( +BlockInputStreamPtr Segment::getConcatVectorIndexBlockInputStream( BitmapFilterPtr bitmap_filter, const SegmentSnapshotPtr & segment_snap, const DMContext & dm_context, @@ -3343,17 +3332,31 @@ std::tuple Segment::getConcatVectorIndexBloc size_t expected_block_size, ReadTag read_tag) { - static constexpr bool NeedRowID = false; // set `is_fast_scan` to true to try to enable clean read auto enable_handle_clean_read = !hasColumn(columns_to_read, MutSup::extra_handle_id); constexpr auto is_fast_scan = true; auto enable_del_clean_read = !hasColumn(columns_to_read, MutSup::delmark_col_id); - SkippableBlockInputStreamPtr stable_stream = segment_snap->stable->tryGetInputStreamWithVectorIndex( + auto columns_to_read_ptr = std::make_shared(columns_to_read); + + auto memtable = segment_snap->delta->getMemTableSetSnapshot(); + auto persisted = segment_snap->delta->getPersistedFileSetSnapshot(); + + auto ctx = VectorIndexStreamCtx::create( + dm_context.global_context.getLightLocalIndexCache(), + dm_context.global_context.getHeavyLocalIndexCache(), + ann_query_info, + columns_to_read_ptr, + persisted->getDataProvider(), + dm_context, + read_tag); + + // The order in the stream: stable1, stable2, ..., persist1, persist2, ..., memtable1, memtable2, ... 
+ + auto stream = segment_snap->stable->getInputStream( dm_context, columns_to_read, read_ranges, - ann_query_info, start_ts, expected_block_size, enable_handle_clean_read, @@ -3362,35 +3365,24 @@ std::tuple Segment::getConcatVectorIndexBloc is_fast_scan, enable_del_clean_read, /* read_packs */ {}, - NeedRowID, - bitmap_filter); - - auto columns_to_read_ptr = std::make_shared(columns_to_read); - - auto memtable = segment_snap->delta->getMemTableSetSnapshot(); - auto persisted_files = segment_snap->delta->getPersistedFileSetSnapshot(); - SkippableBlockInputStreamPtr mem_table_stream = std::make_shared( - dm_context, - memtable, - columns_to_read_ptr, - this->rowkey_range, - read_tag); - SkippableBlockInputStreamPtr persisted_files_stream = ColumnFileSetWithVectorIndexInputStream::tryBuild( - dm_context, - persisted_files, - columns_to_read_ptr, - this->rowkey_range, - persisted_files->getDataProvider(), - ann_query_info, - bitmap_filter, - segment_snap->stable->getDMFilesRows(), - read_tag); + [=](DMFileBlockInputStreamBuilder & builder) { builder.setVecIndexQuery(ctx); }); + + for (const auto & file : persisted->getColumnFiles()) + stream->appendChild(ColumnFileProvideVectorIndexInputStream::createOrFallback(ctx, file), file->getRows()); + for (const auto & file : memtable->getColumnFiles()) + stream->appendChild(ColumnFileProvideVectorIndexInputStream::createOrFallback(ctx, file), file->getRows()); + + auto stream2 = VectorIndexInputStream::create(ctx, bitmap_filter, stream); + // For vector search, there are more likely to return small blocks from different + // sub-streams. Squash blocks to reduce the number of blocks thus improve the + // performance of upper layer. 
+ auto stream3 = std::make_shared( + stream2, + /*min_block_size_rows=*/expected_block_size, + /*min_block_size_bytes=*/0, + dm_context.tracing_id); - auto stream = std::dynamic_pointer_cast>(stable_stream); - assert(stream != nullptr); - stream->appendChild(persisted_files_stream, persisted_files->getRows()); - stream->appendChild(mem_table_stream, memtable->getRows()); - return ConcatVectorIndexBlockInputStream::build(bitmap_filter, stream, ann_query_info); + return stream3; } BlockInputStreamPtr Segment::getLateMaterializationStream( @@ -3553,12 +3545,11 @@ BlockInputStreamPtr Segment::getBitmapFilterInputStream( read_data_block_rows); } - SkippableBlockInputStreamPtr stream; + BlockInputStreamPtr stream; if (executor && executor->ann_query_info) { // For ANN query, try to use vector index to accelerate. - bool is_vector = false; - std::tie(stream, is_vector) = getConcatVectorIndexBlockInputStream( + return getConcatVectorIndexBlockInputStream( bitmap_filter, segment_snap, dm_context, @@ -3569,30 +3560,17 @@ BlockInputStreamPtr Segment::getBitmapFilterInputStream( start_ts, read_data_block_rows, ReadTag::Query); - if (is_vector) - { - // For vector search, there are more likely to return small blocks from different - // sub-streams. Squash blocks to reduce the number of blocks thus improve the - // performance of upper layer. 
- return std::make_shared( - stream, - /*min_block_size_rows=*/read_data_block_rows, - /*min_block_size_bytes=*/0, - dm_context.tracing_id); - } - } - else - { - stream = getConcatSkippableBlockInputStream( - segment_snap, - dm_context, - columns_to_read, - read_ranges, - pack_filter_results, - start_ts, - read_data_block_rows, - ReadTag::Query); } + + stream = getConcatSkippableBlockInputStream( + segment_snap, + dm_context, + columns_to_read, + read_ranges, + pack_filter_results, + start_ts, + read_data_block_rows, + ReadTag::Query); return std::make_shared(columns_to_read, stream, bitmap_filter); } diff --git a/dbms/src/Storages/DeltaMerge/Segment.h b/dbms/src/Storages/DeltaMerge/Segment.h index bb331e8cdef..c21ab89e1b9 100644 --- a/dbms/src/Storages/DeltaMerge/Segment.h +++ b/dbms/src/Storages/DeltaMerge/Segment.h @@ -751,8 +751,7 @@ class Segment const DMFilePackFilterResults & pack_filter_results, UInt64 start_ts, size_t expected_block_size); - // Returns - std::tuple getConcatVectorIndexBlockInputStream( + static BlockInputStreamPtr getConcatVectorIndexBlockInputStream( BitmapFilterPtr bitmap_filter, const SegmentSnapshotPtr & segment_snap, const DMContext & dm_context, diff --git a/dbms/src/Storages/DeltaMerge/SegmentReadTask.cpp b/dbms/src/Storages/DeltaMerge/SegmentReadTask.cpp index 10ed1b6326f..066228ccab6 100644 --- a/dbms/src/Storages/DeltaMerge/SegmentReadTask.cpp +++ b/dbms/src/Storages/DeltaMerge/SegmentReadTask.cpp @@ -132,7 +132,7 @@ SegmentReadTask::SegmentReadTask( remote_page_ids.emplace_back(tiny->getDataPageId()); remote_page_sizes.emplace_back(tiny->getDataPageSize()); ++count; - // Add vector index pages. + // Add local index pages. 
if (auto index_infos = tiny->getIndexInfos(); index_infos) { for (const auto & index_info : *index_infos) diff --git a/dbms/src/Storages/DeltaMerge/SkippableBlockInputStream.h b/dbms/src/Storages/DeltaMerge/SkippableBlockInputStream.h index 4634c30f707..1094fc585df 100644 --- a/dbms/src/Storages/DeltaMerge/SkippableBlockInputStream.h +++ b/dbms/src/Storages/DeltaMerge/SkippableBlockInputStream.h @@ -46,6 +46,7 @@ class SkippableBlockInputStream : public IBlockInputStream using SkippableBlockInputStreamPtr = std::shared_ptr; using SkippableBlockInputStreams = std::vector; +/// A SkippableBlockInputStream that always returns an empty block. class EmptySkippableBlockInputStream : public SkippableBlockInputStream { public: @@ -69,4 +70,43 @@ class EmptySkippableBlockInputStream : public SkippableBlockInputStream ColumnDefines read_columns; }; +template +class AsNopSkippableBlockInputStream; + +/// A SkippableBlockInputStream that does not support any skip operations. +class NopSkippableBlockInputStream : public SkippableBlockInputStream +{ +public: + /// Wraps any stream into a NopSkippableBlockInputStream. + /// Note: After wrapping, you cannot use dynamic_pointer_cast to get the original stream. Use children[0] instead. 
+ template + static auto wrap(const std::shared_ptr & stream) + { + return std::make_shared>(stream); + } + +public: + bool getSkippedRows(size_t &) override { throw Exception("Not implemented", ErrorCodes::NOT_IMPLEMENTED); } + + size_t skipNextBlock() override { throw Exception("Not implemented", ErrorCodes::NOT_IMPLEMENTED); } + + Block readWithFilter(const IColumn::Filter &) override + { + throw Exception("Not implemented", ErrorCodes::NOT_IMPLEMENTED); + } +}; + +template +class AsNopSkippableBlockInputStream : public NopSkippableBlockInputStream +{ +public: + explicit AsNopSkippableBlockInputStream(const std::shared_ptr & stream_) { children.push_back(stream_); } + + String getName() const override { return "AsNopSkippable"; } + + Block getHeader() const override { return children[0]->getHeader(); } + + Block read() override { return children[0]->read(); } +}; + } // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/StableValueSpace.cpp b/dbms/src/Storages/DeltaMerge/StableValueSpace.cpp index ca1aa9f38c0..b6e170ea909 100644 --- a/dbms/src/Storages/DeltaMerge/StableValueSpace.cpp +++ b/dbms/src/Storages/DeltaMerge/StableValueSpace.cpp @@ -441,7 +441,8 @@ void StableValueSpace::drop(const FileProviderPtr & file_provider) } } -SkippableBlockInputStreamPtr StableValueSpace::Snapshot::getInputStream( +template +ConcatSkippableBlockInputStreamPtr StableValueSpace::Snapshot::getInputStream( const DMContext & dm_context, const ColumnDefines & read_columns, const RowKeyRanges & rowkey_ranges, @@ -453,7 +454,7 @@ SkippableBlockInputStreamPtr StableValueSpace::Snapshot::getInputStream( bool is_fast_scan, bool enable_del_clean_read, const std::vector & read_packs, - bool need_row_id) + std::function additional_builder_opt) { LOG_DEBUG( log, @@ -479,31 +480,22 @@ SkippableBlockInputStreamPtr StableValueSpace::Snapshot::getInputStream( .setRowsThreshold(expected_block_size) .setReadPacks(read_packs.size() > i ? 
read_packs[i] : nullptr) .setReadTag(read_tag); - + if (additional_builder_opt) + additional_builder_opt(builder); streams.push_back(builder.build(stable->files[i], read_columns, rowkey_ranges, dm_context.scan_context)); rows.push_back(stable->files[i]->getRows()); } - if (need_row_id) - { - return std::make_shared>( - streams, - std::move(rows), - dm_context.scan_context); - } - else - { - return std::make_shared>( - streams, - std::move(rows), - dm_context.scan_context); - } + + return ConcatSkippableBlockInputStream::create( + std::move(streams), + std::move(rows), + dm_context.scan_context); } -SkippableBlockInputStreamPtr StableValueSpace::Snapshot::tryGetInputStreamWithVectorIndex( +template ConcatSkippableBlockInputStreamPtr StableValueSpace::Snapshot::getInputStream( const DMContext & dm_context, const ColumnDefines & read_columns, const RowKeyRanges & rowkey_ranges, - const ANNQueryInfoPtr & ann_query_info, UInt64 max_data_version, size_t expected_block_size, bool enable_handle_clean_read, @@ -512,65 +504,21 @@ SkippableBlockInputStreamPtr StableValueSpace::Snapshot::tryGetInputStreamWithVe bool is_fast_scan, bool enable_del_clean_read, const std::vector & read_packs, - bool need_row_id, - BitmapFilterPtr bitmap_filter) -{ - LOG_DEBUG( - log, - "StableVS tryGetInputStreamWithVectorIndex" - " start_ts={} enable_handle_clean_read={} is_fast_mode={} enable_del_clean_read={}", - max_data_version, - enable_handle_clean_read, - is_fast_scan, - enable_del_clean_read); - SkippableBlockInputStreams streams; - std::vector rows; - streams.reserve(stable->files.size()); - rows.reserve(stable->files.size()); - - size_t last_rows = 0; + std::function additional_builder_opt); - for (size_t i = 0; i < stable->files.size(); ++i) - { - DMFileBlockInputStreamBuilder builder(dm_context.global_context); - builder.enableCleanRead(enable_handle_clean_read, is_fast_scan, enable_del_clean_read, max_data_version) - .enableColumnCacheLongTerm(dm_context.pk_col_id) - 
.setAnnQureyInfo(ann_query_info) - .setDMFilePackFilterResult(pack_filter_results.size() > i ? pack_filter_results[i] : nullptr) - .setColumnCache(column_caches[i]) - .setTracingID(dm_context.tracing_id) - .setRowsThreshold(expected_block_size) - .setReadPacks(read_packs.size() > i ? read_packs[i] : nullptr) - .setReadTag(read_tag); - if (bitmap_filter) - { - builder.setBitmapFilter( - BitmapFilterView(bitmap_filter, last_rows, last_rows + stable->files[i]->getRows())); - last_rows += stable->files[i]->getRows(); - } - - streams.push_back(builder.build( // - stable->files[i], - read_columns, - rowkey_ranges, - dm_context.scan_context)); - rows.push_back(stable->files[i]->getRows()); - } - if (need_row_id) - { - return std::make_shared>( - streams, - std::move(rows), - dm_context.scan_context); - } - else - { - return std::make_shared>( - streams, - std::move(rows), - dm_context.scan_context); - } -} +template ConcatSkippableBlockInputStreamPtr StableValueSpace::Snapshot::getInputStream( + const DMContext & dm_context, + const ColumnDefines & read_columns, + const RowKeyRanges & rowkey_ranges, + UInt64 max_data_version, + size_t expected_block_size, + bool enable_handle_clean_read, + ReadTag read_tag, + const DMFilePackFilterResults & pack_filter_results, + bool is_fast_scan, + bool enable_del_clean_read, + const std::vector & read_packs, + std::function additional_builder_opt); RowsAndBytes StableValueSpace::Snapshot::getApproxRowsAndBytes(const DMContext & dm_context, const RowKeyRange & range) const diff --git a/dbms/src/Storages/DeltaMerge/StableValueSpace.h b/dbms/src/Storages/DeltaMerge/StableValueSpace.h index 8846c06711b..679367dec8a 100644 --- a/dbms/src/Storages/DeltaMerge/StableValueSpace.h +++ b/dbms/src/Storages/DeltaMerge/StableValueSpace.h @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -38,6 +39,8 @@ struct WriteBatches; class StableValueSpace; using StableValueSpacePtr = std::shared_ptr; +class 
DMFileBlockInputStreamBuilder; + class StableValueSpace : public std::enable_shared_from_this { public: @@ -224,7 +227,8 @@ class StableValueSpace : public std::enable_shared_from_this } } - SkippableBlockInputStreamPtr getInputStream( + template + ConcatSkippableBlockInputStreamPtr getInputStream( const DMContext & dm_context, // const ColumnDefines & read_columns, const RowKeyRanges & rowkey_ranges, @@ -236,23 +240,7 @@ class StableValueSpace : public std::enable_shared_from_this bool is_fast_scan = false, bool enable_del_clean_read = false, const std::vector & read_packs = {}, - bool need_row_id = false); - - SkippableBlockInputStreamPtr tryGetInputStreamWithVectorIndex( - const DMContext & dm_context, - const ColumnDefines & read_columns, - const RowKeyRanges & rowkey_ranges, - const ANNQueryInfoPtr & ann_query_info, - UInt64 max_data_version, - size_t expected_block_size, - bool enable_handle_clean_read, - ReadTag read_tag, - const DMFilePackFilterResults & pack_filter_results, - bool is_fast_scan = false, - bool enable_del_clean_read = false, - const std::vector & read_packs = {}, - bool need_row_id = false, - BitmapFilterPtr bitmap_filter = nullptr); + std::function additional_builder_opt = nullptr); RowsAndBytes getApproxRowsAndBytes(const DMContext & dm_context, const RowKeyRange & range) const; diff --git a/dbms/src/Storages/DeltaMerge/VectorIndexBlockInputStream.h b/dbms/src/Storages/DeltaMerge/VectorIndexBlockInputStream.h deleted file mode 100644 index 68ed953b6c6..00000000000 --- a/dbms/src/Storages/DeltaMerge/VectorIndexBlockInputStream.h +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2024 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include - - -namespace DB::DM -{ - -class VectorIndexBlockInputStream : public SkippableBlockInputStream -{ -public: - bool getSkippedRows(size_t &) override - { - RUNTIME_CHECK_MSG(false, "DMFileWithVectorIndexBlockInputStream does not support getSkippedRows"); - } - - size_t skipNextBlock() override - { - RUNTIME_CHECK_MSG(false, "DMFileWithVectorIndexBlockInputStream does not support skipNextBlock"); - } - - Block readWithFilter(const IColumn::Filter &) override - { - // We don't support the normal late materialization, because - // we are already doing it. - RUNTIME_CHECK_MSG(false, "DMFileWithVectorIndexBlockInputStream does not support late materialization"); - } - - // Load vector index and search results. - // Return the rowids of the selected rows. - virtual std::vector load() = 0; - - // Set the real selected rows offset (local offset). This is used to update Packs/ColumnFiles to read. - // For example, DMFile should update DMFilePackFilter, only packs with selected rows will be read. - virtual void setSelectedRows(const std::span & selected_rows) = 0; -}; - -using VectorIndexBlockInputStreamPtr = std::shared_ptr; - -} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/VersionChain/ColumnView.h b/dbms/src/Storages/DeltaMerge/VersionChain/ColumnView.h new file mode 100644 index 00000000000..0220432f018 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/VersionChain/ColumnView.h @@ -0,0 +1,158 @@ +// Copyright 2025 PingCAP, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +namespace DB::DM +{ +// `ColumnView` is a class that provides unified access to both Int64 handles and String handles. +template +class ColumnView +{ + static_assert(false, "Only support Int64 and String"); +}; + +template <> +class ColumnView +{ +public: + ColumnView(const IColumn & col) + : data(toColumnVectorData(col)) + {} + + auto begin() const { return data.begin(); } + + auto end() const { return data.end(); } + + Int64 operator[](size_t index) const + { + assert(index < data.size()); + return data[index]; + } + + size_t size() const { return data.size(); } + +private: + const PaddedPODArray & data; +}; + +template <> +class ColumnView +{ +public: + ColumnView(const IColumn & col) + : offsets(typeid_cast(col).getOffsets()) + , chars(typeid_cast(col).getChars()) + {} + + class Iterator + { + public: + using iterator_category = std::random_access_iterator_tag; + using value_type = std::string_view; + using difference_type = std::ptrdiff_t; + + Iterator(const IColumn::Offsets & offsets, const ColumnString::Chars_t & chars, size_t pos) + : pos(pos) + , offsets(&offsets) + , chars(&chars) + {} + + value_type operator*() const + { + assert((*offsets)[-1] == 0); + const auto off = (*offsets)[pos - 1]; + const auto size = (*offsets)[pos] - (*offsets)[pos - 1] - 1; + return std::string_view(reinterpret_cast(chars->data() + off), size); + } + + Iterator 
operator+(difference_type n) { return Iterator{*offsets, *chars, pos + n}; } + + Iterator operator-(difference_type n) { return Iterator{*offsets, *chars, pos - n}; } + + difference_type operator-(const Iterator & other) const { return pos - other.pos; } + + Iterator & operator++() + { + ++pos; + return *this; + } + + Iterator & operator--() + { + --pos; + return *this; + } + + Iterator operator++(int) + { + Iterator tmp = *this; + ++pos; + return tmp; + } + + Iterator operator--(int) + { + Iterator tmp = *this; + --pos; + return tmp; + } + + Iterator & operator+=(difference_type n) + { + pos += n; + return *this; + } + + Iterator & operator-=(difference_type n) + { + pos -= n; + return *this; + } + + // Perform a lexicographic comparison of elements. + // Assume `this->offsets == other.offsets && this->chars == other.chars`, + // so it equal to `this->pos <=> other.pos`. + auto operator<=>(const Iterator & other) const = default; + + private: + size_t pos = 0; + const IColumn::Offsets * offsets; // Using pointer for operator assignment + const ColumnString::Chars_t * chars; + }; + + auto begin() const { return Iterator(offsets, chars, 0); } + + auto end() const { return Iterator(offsets, chars, offsets.size()); } + + std::string_view operator[](size_t index) const + { + assert(index < offsets.size()); + const auto off = offsets[index - 1]; + const auto size = offsets[index] - offsets[index - 1] - 1; + return std::string_view(reinterpret_cast(chars.data() + off), size); + } + + size_t size() const { return offsets.size(); } + +private: + const IColumn::Offsets & offsets; + const ColumnString::Chars_t & chars; +}; + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/dtpb/index_file.proto b/dbms/src/Storages/DeltaMerge/dtpb/index_file.proto index 4ec2b1539ae..ad8fe3c0c39 100644 --- a/dbms/src/Storages/DeltaMerge/dtpb/index_file.proto +++ b/dbms/src/Storages/DeltaMerge/dtpb/index_file.proto @@ -19,6 +19,7 @@ package dtpb; enum IndexFileKind { INVALID = 
0; VECTOR_INDEX = 1; + INVERTED_INDEX = 2; } message IndexFilePropsV2 { @@ -28,6 +29,7 @@ message IndexFilePropsV2 { oneof prop { IndexFilePropsV2Vector vector_index = 31; + IndexFilePropsInverted inverted_index = 32; } } @@ -42,3 +44,7 @@ message IndexFilePropsV2Vector { optional string distance_metric = 2; // The value is tipb.VectorDistanceMetric optional uint64 dimensions = 3; } + +message IndexFilePropsInverted { + optional uint32 uncompressed_size = 1; +} diff --git a/dbms/src/Storages/DeltaMerge/tests/DMTestEnv.h b/dbms/src/Storages/DeltaMerge/tests/DMTestEnv.h index fadeacd785e..6b737c8fe82 100644 --- a/dbms/src/Storages/DeltaMerge/tests/DMTestEnv.h +++ b/dbms/src/Storages/DeltaMerge/tests/DMTestEnv.h @@ -31,11 +31,7 @@ #include -namespace DB -{ -namespace DM -{ -namespace tests +namespace DB::DM::tests { #define GET_REGION_RANGE(start, end, table_id) \ RowKeyRange::fromHandleRange(::DB::DM::HandleRange((start), (end))).toRegionRange((table_id)) @@ -112,6 +108,14 @@ inline String genMockCommonHandle(Int64 value, size_t rowkey_column_size) return ss.releaseStr(); } +inline Int64 decodeMockCommonHandle(const String & s) +{ + size_t cursor = 0; + auto flag = ::DB::DecodeUInt(cursor, s); + RUNTIME_CHECK(flag == static_cast(TiDB::CodecFlagInt), flag); + return ::DB::DecodeInt64(cursor, s); +} + class DMTestEnv { public: @@ -305,36 +309,32 @@ class DMTestEnv size_t rowkey_column_size = 1, bool with_internal_columns = true, bool is_deleted = false, - bool with_nullable_uint64 = false) + bool with_nullable_uint64 = false, + bool including_right_boundary = false) // [beg, end) or [beg, end] { Block block; - const size_t num_rows = (end - beg); + const size_t num_rows = (end - beg) + including_right_boundary; + std::vector handles(num_rows); + std::iota(handles.begin(), handles.end(), beg); + if (reversed) + std::reverse(handles.begin(), handles.end()); if (is_common_handle) { - // common_pk_col Strings values; - for (size_t i = 0; i < num_rows; i++) - { - Int64 
value = reversed ? end - 1 - i : beg + i; - values.emplace_back(genMockCommonHandle(value, rowkey_column_size)); - } + for (Int64 h : handles) + values.emplace_back(genMockCommonHandle(h, rowkey_column_size)); block.insert(DB::tests::createColumn(std::move(values), pk_name_, pk_col_id)); } else { // int-like pk_col - block.insert(ColumnWithTypeAndName{ - DB::tests::makeColumn(pk_type, createNumbers(beg, end, reversed)), - pk_type, - pk_name_, - pk_col_id}); + block.insert( + ColumnWithTypeAndName{DB::tests::makeColumn(pk_type, handles), pk_type, pk_name_, pk_col_id}); // add extra column if need if (pk_col_id != MutSup::extra_handle_id) { block.insert(ColumnWithTypeAndName{ - DB::tests::makeColumn( - MutSup::getExtraHandleColumnIntType(), - createNumbers(beg, end, reversed)), + DB::tests::makeColumn(MutSup::getExtraHandleColumnIntType(), handles), MutSup::getExtraHandleColumnIntType(), MutSup::extra_handle_column_name, MutSup::extra_handle_id}); @@ -568,7 +568,4 @@ class DMTestEnv return num++; } }; - -} // namespace tests -} // namespace DM -} // namespace DB +} // namespace DB::DM::tests diff --git a/dbms/src/Storages/DeltaMerge/tests/gtest_column_cache.cpp b/dbms/src/Storages/DeltaMerge/tests/gtest_column_cache.cpp new file mode 100644 index 00000000000..0076b468e62 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/tests/gtest_column_cache.cpp @@ -0,0 +1,121 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include + +namespace DB::DM::tests +{ + +TEST(ColumnCacheTest, BasicOperations) +try +{ + auto cache = ColumnCache(); + + // Test put and get + ColumnID col_id = 1; + auto data = genSequence("[0, 5)"); + auto col = ::DB::tests::createColumn(data, "", col_id).column; + cache.tryPutColumn(0, 1, col, 0, 5); + + ASSERT_TRUE(cache.getReadStrategy(0, 1, col_id)[0].second == ColumnCache::Strategy::Memory); + + { + // Get the column that exactly matches the cache + auto get_col = cache.getColumn(0, 1, 5, col_id); + ASSERT_EQ(get_col->size(), 5); + } + { + // Get a part of the cache + auto get_col = cache.getColumn(0, 1, 3, col_id); + ASSERT_EQ(get_col->size(), 3); + } + + // Test delete + cache.delColumn(1, 1); + ASSERT_TRUE(cache.getReadStrategy(0, 1, col_id)[0].second == ColumnCache::Strategy::Disk); +} +CATCH + +TEST(ColumnCacheTest, RangeStrategy) +try +{ + auto cache = ColumnCache(); + + ColumnID col_id = 1; + // Prepare test data + { + auto data = genSequence("[0, 10)"); + auto col = ::DB::tests::createColumn(data, "", col_id).column; + // pack0, pack1 share the same ColumnPtr with different rows_offset + cache.tryPutColumn(0, col_id, col, 0, 5); + cache.tryPutColumn(1, col_id, col, 5, 5); + } + { + auto data = genSequence("[10, 20)"); + auto col = ::DB::tests::createColumn(data, "", col_id).column; + // pack2, pack3 share the same ColumnPtr with different rows_offset + cache.tryPutColumn(2, col_id, col, 0, 5); + cache.tryPutColumn(3, col_id, col, 5, 5); + } + + // Test continuous range + { + // read pack0, pack1 + auto strategies = cache.getReadStrategy(0, 2, col_id); + ASSERT_EQ(strategies.size(), 1); + ASSERT_EQ(strategies[0].second, ColumnCache::Strategy::Memory); + auto get_col = cache.getColumn(0, 2, 10, col_id); + auto data = genSequence("[0, 10)"); + const auto & actual_data = toColumnVectorData(get_col); + ASSERT_TRUE(sequenceEqual(data, actual_data)); + } + { + // read pack1, pack2 + auto strategies = 
cache.getReadStrategy(1, 2, col_id); + ASSERT_EQ(strategies.size(), 1); + ASSERT_EQ(strategies[0].second, ColumnCache::Strategy::Memory); + auto get_col = cache.getColumn(1, 3, 10, col_id); + auto data = genSequence("[5, 15)"); + const auto & actual_data = toColumnVectorData(get_col); + ASSERT_TRUE(sequenceEqual(data, actual_data)); + } + + // Test mixed range + cache.delColumn(col_id, 1); + auto strategies = cache.getReadStrategy(0, 2, col_id); + ASSERT_EQ(strategies.size(), 2); + ASSERT_EQ(strategies[0].second, ColumnCache::Strategy::Disk); + ASSERT_EQ(strategies[1].second, ColumnCache::Strategy::Memory); +} +CATCH + +TEST(ColumnCacheTest, CleanReadStrategy) +try +{ + std::vector clean_packs = {0, 2, 4}; + auto strategies = ColumnCache::getCleanReadStrategy(0, 5, clean_packs); + + ASSERT_EQ(strategies.size(), 5); + ASSERT_EQ(strategies[0].second, ColumnCache::Strategy::Memory); + ASSERT_EQ(strategies[1].second, ColumnCache::Strategy::Disk); + ASSERT_EQ(strategies[2].second, ColumnCache::Strategy::Memory); + ASSERT_EQ(strategies[3].second, ColumnCache::Strategy::Disk); + ASSERT_EQ(strategies[4].second, ColumnCache::Strategy::Memory); +} +CATCH + +} // namespace DB::DM::tests diff --git a/dbms/src/Storages/DeltaMerge/tests/gtest_dm_delta_merge_store.cpp b/dbms/src/Storages/DeltaMerge/tests/gtest_dm_delta_merge_store.cpp index f499edde1f5..f9f064c4081 100644 --- a/dbms/src/Storages/DeltaMerge/tests/gtest_dm_delta_merge_store.cpp +++ b/dbms/src/Storages/DeltaMerge/tests/gtest_dm_delta_merge_store.cpp @@ -2943,7 +2943,12 @@ try auto table_column_defines = DMTestEnv::getDefaultColumns(); table_column_defines->emplace_back(legacy_str_cd); dropDataOnDisk(getTemporaryPath()); + + auto tmp = STORAGE_FORMAT_CURRENT; + setStorageFormat(7); // set to legacy format temporary. store = reload(table_column_defines); + setStorageFormat(tmp.identifier); // Reset to current format. 
+ auto block = createBlock(legacy_str_cd, 0, 128); store->write(*db_context, db_context->getSettingsRef(), block); @@ -2964,11 +2969,26 @@ try // Mock that after restart, the data type has been changed to new serialize. But still can read old // serialized format data. auto table_column_defines = DMTestEnv::getDefaultColumns(); - table_column_defines->emplace_back(str_cd); + table_column_defines->emplace_back(legacy_str_cd); store = reload(table_column_defines); + + // Verify that the string column with old data type name will be automatically + // converted into new serialized format. + bool legacy_str_is_converted = false; + auto store_cds = *store->getStoreColumns(); + for (const auto & cd : store_cds) + { + if (cd.id == legacy_str_cd.id) + { + ASSERT_EQ(cd.type->getName(), str_cd.type->getName()); + legacy_str_is_converted = true; + } + } + ASSERT_TRUE(legacy_str_is_converted); } { + // Still can read old serialized format data auto in = store->read( *db_context, db_context->getSettingsRef(), diff --git a/dbms/src/Storages/DeltaMerge/tests/gtest_dm_delta_merge_store_fast_add_peer.cpp b/dbms/src/Storages/DeltaMerge/tests/gtest_dm_delta_merge_store_fast_add_peer.cpp index 3e2a71a1bb5..95e64b85a56 100644 --- a/dbms/src/Storages/DeltaMerge/tests/gtest_dm_delta_merge_store_fast_add_peer.cpp +++ b/dbms/src/Storages/DeltaMerge/tests/gtest_dm_delta_merge_store_fast_add_peer.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -32,7 +33,6 @@ #include #include #include -#include #include #include #include @@ -390,9 +390,7 @@ try db_context->getSettingsRef(), RowKeyRange::newAll(false, 1), checkpoint_info); - auto start = RecordKVFormat::genKey(table_id, 0); - auto end = RecordKVFormat::genKey(table_id, 10); - RegionPtr dummy_region = tests::makeRegion(checkpoint_info->region_id, start, end, nullptr); + RegionPtr dummy_region = RegionBench::makeRegionForTable(checkpoint_info->region_id, table_id, 0, 10, nullptr); 
store->ingestSegmentsFromCheckpointInfo( *db_context, db_context->getSettingsRef(), @@ -523,9 +521,7 @@ try db_context->getSettingsRef(), RowKeyRange::fromHandleRange(HandleRange(0, num_rows_write / 2)), checkpoint_info); - auto start = RecordKVFormat::genKey(table_id, 0); - auto end = RecordKVFormat::genKey(table_id, 10); - RegionPtr dummy_region = tests::makeRegion(checkpoint_info->region_id, start, end, nullptr); + RegionPtr dummy_region = RegionBench::makeRegionForTable(checkpoint_info->region_id, table_id, 0, 10, nullptr); store->ingestSegmentsFromCheckpointInfo( *db_context, db_context->getSettingsRef(), @@ -548,9 +544,7 @@ try db_context->getSettingsRef(), RowKeyRange::fromHandleRange(HandleRange(num_rows_write / 2, num_rows_write)), checkpoint_info); - auto start = RecordKVFormat::genKey(table_id, 0); - auto end = RecordKVFormat::genKey(table_id, 10); - RegionPtr dummy_region = tests::makeRegion(checkpoint_info->region_id, start, end, nullptr); + RegionPtr dummy_region = RegionBench::makeRegionForTable(checkpoint_info->region_id, table_id, 0, 10, nullptr); store->ingestSegmentsFromCheckpointInfo( *db_context, db_context->getSettingsRef(), diff --git a/dbms/src/Storages/DeltaMerge/tests/gtest_dm_delta_merge_store_vector_index.cpp b/dbms/src/Storages/DeltaMerge/tests/gtest_dm_delta_merge_store_vector_index.cpp index 5bda8d3300e..3ee5e4f3200 100644 --- a/dbms/src/Storages/DeltaMerge/tests/gtest_dm_delta_merge_store_vector_index.cpp +++ b/dbms/src/Storages/DeltaMerge/tests/gtest_dm_delta_merge_store_vector_index.cpp @@ -151,6 +151,21 @@ class DeltaMergeStoreVectorTest } } + void triggerFlushCache() const + { + std::vector all_segments; + { + std::shared_lock lock(store->read_write_mutex); + for (const auto & [_, segment] : store->id_to_segment) + all_segments.push_back(segment); + } + auto dm_context = store->newDMContext(*db_context, db_context->getSettingsRef()); + for (const auto & segment : all_segments) + { + ASSERT_TRUE(segment->flushCache(*dm_context)); 
+ } + } + void triggerCompactDelta() const { std::vector all_segments; @@ -245,24 +260,16 @@ try read(range, EMPTY_FILTER, colVecFloat32("[64, 256)", vec_column_name, vec_column_id)); } - auto ann_query_info = std::make_shared(); - ann_query_info->set_column_id(vec_column_id); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - // read with ANN query { - ann_query_info->set_top_k(2); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({127.5})); - + const auto ann_query_info = annQueryInfoTopK({.vec = {127.5}, .top_k = 2}); auto filter = std::make_shared(ann_query_info); read(range, filter, createVecFloat32Column({{127.0}, {128.0}})); } // read with ANN query { - ann_query_info->set_top_k(2); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({72.1})); - + const auto ann_query_info = annQueryInfoTopK({.vec = {72.1}, .top_k = 2}); auto filter = std::make_shared(ann_query_info); read(range, filter, createVecFloat32Column({{72.0}, {73.0}})); } @@ -295,24 +302,16 @@ try read(range, EMPTY_FILTER, colVecFloat32("[0, 256)", vec_column_name, vec_column_id)); } - auto ann_query_info = std::make_shared(); - ann_query_info->set_column_id(vec_column_id); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - // read with ANN query { - ann_query_info->set_top_k(2); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({72.1})); - + const auto ann_query_info = annQueryInfoTopK({.vec = {72.1}, .top_k = 2}); auto filter = std::make_shared(ann_query_info); read(range, filter, createVecFloat32Column({{72.0}, {73.0}})); } // read with ANN query { - ann_query_info->set_top_k(2); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({127.5})); - + const auto ann_query_info = annQueryInfoTopK({.vec = {127.5}, .top_k = 2}); auto filter = std::make_shared(ann_query_info); read(range, filter, createVecFloat32Column({{127.0}, {128.0}})); } @@ -352,15 +351,9 @@ try read(range, EMPTY_FILTER, colVecFloat32("[0, 130)", vec_column_name, 
vec_column_id)); } - auto ann_query_info = std::make_shared(); - ann_query_info->set_column_id(vec_column_id); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - // read with ANN query { - ann_query_info->set_top_k(1); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({72.0})); - + const auto ann_query_info = annQueryInfoTopK({.vec = {72.0}, .top_k = 1}); auto filter = std::make_shared(ann_query_info); // [0, 128) with vector index return 72.0, [128, 130) without vector index return all. read(range, filter, createVecFloat32Column({{72.0}, {128.0}, {129.0}})); @@ -368,9 +361,7 @@ try // read with ANN query { - ann_query_info->set_top_k(1); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({72.1})); - + const auto ann_query_info = annQueryInfoTopK({.vec = {72.1}, .top_k = 1}); auto filter = std::make_shared(ann_query_info); // [0, 128) with vector index return 72.0, [128, 130) without vector index return all. read(range, filter, createVecFloat32Column({{72.0}, {128.0}, {129.0}})); @@ -411,15 +402,9 @@ try read(range, EMPTY_FILTER, colVecFloat32("[0, 4)", vec_column_name, vec_column_id)); } - auto ann_query_info = std::make_shared(); - ann_query_info->set_column_id(vec_column_id); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - // read with ANN query { - ann_query_info->set_top_k(1); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0})); - + const auto ann_query_info = annQueryInfoTopK({.vec = {1.0}, .top_k = 1}); auto filter = std::make_shared(ann_query_info); // [0, 4) without vector index return all. read(range, filter, createVecFloat32Column({{0.0}, {1.0}, {2.0}, {3.0}})); @@ -427,9 +412,7 @@ try // read with ANN query { - ann_query_info->set_top_k(1); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.1})); - + const auto ann_query_info = annQueryInfoTopK({.vec = {1.1}, .top_k = 1}); auto filter = std::make_shared(ann_query_info); // [0, 4) without vector index return all. 
read(range, filter, createVecFloat32Column({{0.0}, {1.0}, {2.0}, {3.0}})); @@ -494,27 +477,17 @@ try colVecFloat32(fmt::format("[0, {})", num_rows_write), vec_column_name, vec_column_id)); } - auto ann_query_info = std::make_shared(); - ann_query_info->set_column_id(vec_column_id); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - // read with ANN query { - ann_query_info->set_top_k(1); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({2.0})); - + const auto ann_query_info = annQueryInfoTopK({.vec = {2.0}, .top_k = 1}); auto filter = std::make_shared(ann_query_info); - read(left_segment_range, filter, createVecFloat32Column({{2.0}})); } // read with ANN query { - ann_query_info->set_top_k(1); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({222.1})); - + const auto ann_query_info = annQueryInfoTopK({.vec = {222.1}, .top_k = 1}); auto filter = std::make_shared(ann_query_info); - read(left_segment_range, filter, createVecFloat32Column({{127.0}})); } @@ -533,21 +506,15 @@ try // read with ANN query { - ann_query_info->set_top_k(1); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({2.0})); - + const auto ann_query_info = annQueryInfoTopK({.vec = {2.0}, .top_k = 1}); auto filter = std::make_shared(ann_query_info); - read(range, filter, createVecFloat32Column({{2.0}})); } // read with ANN query { - ann_query_info->set_top_k(1); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({122.1})); - + const auto ann_query_info = annQueryInfoTopK({.vec = {122.1}, .top_k = 1}); auto filter = std::make_shared(ann_query_info); - read(range, filter, createVecFloat32Column({{122.0}})); } } @@ -622,27 +589,17 @@ try colVecFloat32(fmt::format("[0, {})", num_rows_write), vec_column_name, vec_column_id)); } - auto ann_query_info = std::make_shared(); - ann_query_info->set_column_id(vec_column_id); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - // read with ANN query { - ann_query_info->set_top_k(1); - 
ann_query_info->set_ref_vec_f32(encodeVectorFloat32({2.0})); - + const auto ann_query_info = annQueryInfoTopK({.vec = {2.0}, .top_k = 1}); auto filter = std::make_shared(ann_query_info); - read(left_segment_range, filter, createVecFloat32Column({{2.0}})); } // read with ANN query { - ann_query_info->set_top_k(1); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({222.1})); - + const auto ann_query_info = annQueryInfoTopK({.vec = {222.1}, .top_k = 1}); auto filter = std::make_shared(ann_query_info); - read(left_segment_range, filter, createVecFloat32Column({{127.0}})); } @@ -661,21 +618,15 @@ try // read with ANN query { - ann_query_info->set_top_k(1); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({2.0})); - + const auto ann_query_info = annQueryInfoTopK({.vec = {2.0}, .top_k = 1}); auto filter = std::make_shared(ann_query_info); - read(range, filter, createVecFloat32Column({{2.0}})); } // read with ANN query { - ann_query_info->set_top_k(1); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({122.1})); - + const auto ann_query_info = annQueryInfoTopK({.vec = {122.1}, .top_k = 1}); auto filter = std::make_shared(ann_query_info); - read(range, filter, createVecFloat32Column({{122.0}})); } } @@ -745,27 +696,17 @@ try read(range, EMPTY_FILTER, colVecFloat32("[0, 128)", vec_column_name, vec_column_id)); } - auto ann_query_info = std::make_shared(); - ann_query_info->set_column_id(vec_column_id); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - // read with ANN query { - ann_query_info->set_top_k(1); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({2.0})); - + const auto ann_query_info = annQueryInfoTopK({.vec = {2.0}, .top_k = 1}); auto filter = std::make_shared(ann_query_info); - read(range, filter, createVecFloat32Column({{2.0}})); } // read with ANN query { - ann_query_info->set_top_k(1); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({2.1})); - + const auto ann_query_info = annQueryInfoTopK({.vec = {2.1}, .top_k = 
1}); auto filter = std::make_shared(ann_query_info); - read(range, filter, createVecFloat32Column({{2.0}})); } } @@ -833,27 +774,17 @@ try read(range, EMPTY_FILTER, colVecFloat32("[0, 256)", vec_column_name, vec_column_id)); } - auto ann_query_info = std::make_shared(); - ann_query_info->set_column_id(vec_column_id); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - // read with ANN query { - ann_query_info->set_top_k(1); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({2.0})); - + const auto ann_query_info = annQueryInfoTopK({.vec = {2.0}, .top_k = 1}); auto filter = std::make_shared(ann_query_info); - read(range, filter, createVecFloat32Column({{2.0}})); } // read with ANN query { - ann_query_info->set_top_k(1); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({222.1})); - + const auto ann_query_info = annQueryInfoTopK({.vec = {222.1}, .top_k = 1}); auto filter = std::make_shared(ann_query_info); - read(range, filter, createVecFloat32Column({{222.0}})); } } @@ -878,6 +809,9 @@ try // write [128, 256) to store write(num_rows_write, num_rows_write * 2); + // trigger FlushCache for all segments + triggerFlushCache(); + { // Add vecotr index TiDB::TableInfo new_table_info_with_vector_index; @@ -904,9 +838,6 @@ try ASSERT_EQ(store->local_index_infos->size(), 1); } - // trigger FlushCache for all segments - triggerFlushCacheAndEnsureDeltaLocalIndex(); - // check delta index has built for all segments waitDeltaIndexReady(); // check stable index has built for all segments @@ -919,16 +850,9 @@ try read(range, EMPTY_FILTER, colVecFloat32("[0, 256)", vec_column_name, vec_column_id)); } - auto ann_query_info = std::make_shared(); - ann_query_info->set_index_id(2); - ann_query_info->set_column_id(vec_column_id); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - // read with ANN query { - ann_query_info->set_top_k(1); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({2.0})); - + const auto ann_query_info = 
annQueryInfoTopK({.vec = {2.0}, .top_k = 1, .index_id = 2}); auto filter = std::make_shared(ann_query_info); read(range, filter, createVecFloat32Column({{2.0}})); @@ -936,9 +860,7 @@ try // read with ANN query { - ann_query_info->set_top_k(1); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({222.1})); - + const auto ann_query_info = annQueryInfoTopK({.vec = {222.1}, .top_k = 1, .index_id = 2}); auto filter = std::make_shared(ann_query_info); read(range, filter, createVecFloat32Column({{222.0}})); @@ -958,6 +880,121 @@ try } CATCH +TEST_F(DeltaMergeStoreVectorTest, DDLAddMultipleVectorIndex) +try +{ + { + auto indexes = std::make_shared(); + store = reload(indexes); + ASSERT_EQ(store->getLocalIndexInfosSnapshot(), nullptr); + } + + const size_t num_rows_write = 128; + + // write [0, 128) to store + write(0, num_rows_write); + // trigger mergeDelta for all segments + triggerMergeDelta(); + + // write [128, 256) to store + write(num_rows_write, num_rows_write * 2); + + // trigger FlushCache for all segments + triggerFlushCache(); + + auto add_vector_index = [&](std::vector index_id, std::vector metrics) { + TiDB::TableInfo new_table_info_with_vector_index; + TiDB::ColumnInfo column_info; + column_info.name = VectorIndexTestUtils::vec_column_name; + column_info.id = VectorIndexTestUtils::vec_column_id; + new_table_info_with_vector_index.columns.emplace_back(column_info); + TiDB::IndexColumnInfo index_col_info; + index_col_info.name = VectorIndexTestUtils::vec_column_name; + index_col_info.offset = 0; + for (size_t i = 0; i < index_id.size(); ++i) + { + TiDB::IndexInfo index; + index.id = index_id[i]; + index.idx_cols.push_back(index_col_info); + index.vector_index = TiDB::VectorIndexDefinitionPtr(new TiDB::VectorIndexDefinition{ + .kind = tipb::VectorIndexKind::HNSW, + .dimension = 1, + .distance_metric = metrics[i], + }); + new_table_info_with_vector_index.index_infos.emplace_back(index); + } + // apply local index change, should + // - create the local index + 
// - generate the background tasks for building index on stable and delta + store->applyLocalIndexChange(new_table_info_with_vector_index); + ASSERT_EQ(store->local_index_infos->size(), index_id.size()); + + // check delta index has built for all segments + waitDeltaIndexReady(); + // check stable index has built for all segments + waitStableLocalIndexReady(); + }; + + const auto range = RowKeyRange::newAll(store->is_common_handle, store->rowkey_column_size); + auto query = [&](IndexID index_id, + tipb::VectorDistanceMetric metric, + const InferredDataVector & result_1, + const InferredDataVector & result_2) { + // read from store + { + read(range, EMPTY_FILTER, colVecFloat32("[0, 256)", vec_column_name, vec_column_id)); + } + + // read with ANN query + { + const auto ann_query_info = annQueryInfoTopK({ + .vec = {2.0}, + .top_k = 1, + .index_id = index_id, + .distance_metric = metric, + }); + auto filter = std::make_shared(ann_query_info); + + read(range, filter, createVecFloat32Column(result_1)); + } + + // read with ANN query + { + const auto ann_query_info = annQueryInfoTopK({ + .vec = {222.1}, + .top_k = 1, + .index_id = index_id, + .distance_metric = metric, + }); + auto filter = std::make_shared(ann_query_info); + + read(range, filter, createVecFloat32Column(result_2)); + } + }; + + // Add COSINE vector index + add_vector_index({1}, {tipb::VectorDistanceMetric::COSINE}); + query(1, tipb::VectorDistanceMetric::COSINE, {{129.0}}, {{129.0}}); + + // Add L2 vector index + add_vector_index({1, 2}, {tipb::VectorDistanceMetric::COSINE, tipb::VectorDistanceMetric::L2}); + query(1, tipb::VectorDistanceMetric::COSINE, {{129.0}}, {{129.0}}); + query(2, tipb::VectorDistanceMetric::L2, {{2.0}}, {{222.0}}); + + { + // vector index is dropped + TiDB::TableInfo new_table_info_with_vector_index; + TiDB::ColumnInfo column_info; + column_info.name = VectorIndexTestUtils::vec_column_name; + column_info.id = VectorIndexTestUtils::vec_column_id; + 
new_table_info_with_vector_index.columns.emplace_back(column_info); + // apply local index change, should drop the local index + store->applyLocalIndexChange(new_table_info_with_vector_index); + ASSERT_EQ(store->local_index_infos->size(), 0); + } +} +CATCH + TEST_F(DeltaMergeStoreVectorTest, DDLAddVectorIndexErrorMemoryExceed) try { diff --git a/dbms/src/Storages/DeltaMerge/tests/gtest_dm_vector_index.cpp b/dbms/src/Storages/DeltaMerge/tests/gtest_dm_vector_index.cpp index 7bfee0b8e8d..32b37c39d5e 100644 --- a/dbms/src/Storages/DeltaMerge/tests/gtest_dm_vector_index.cpp +++ b/dbms/src/Storages/DeltaMerge/tests/gtest_dm_vector_index.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -62,6 +63,151 @@ extern const char file_cache_fg_download_fail[]; namespace DB::DM::tests { +TEST(VectorIndexInputStream, NormalStream) +try +{ + // When a normal stream is provided in the VectorIndexInputStream, BitmapFilter should be applied. + + auto make_block_stream = [] { + return NopSkippableBlockInputStream::wrap(std::make_shared(Block{createColumns({ + createColumn({7, 4, 7, 0, 1, 2, 3}, "a"), + })})); + }; + + // VectorIndexInputStream does not need this information, but ctx needs at least a correct vec column. 
+ auto ann = std::make_shared(); + ann->set_column_id(1); + auto ctx = VectorIndexStreamCtx::createForStableOnlyTests( + ann, + std::make_shared( + ColumnDefines{ColumnDefine{1, "vec", tests::typeFromString("Array(Float32)")}})); + + auto stream = VectorIndexTestUtils::wrapVectorStream( // + ctx, + make_block_stream(), + std::make_shared(7, true)); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + {"a"}, + createColumns({ + createColumn({7, 4, 7, 0, 1, 2, 3}), + })); + + stream = VectorIndexTestUtils::wrapVectorStream(ctx, make_block_stream(), std::make_shared(7, false)); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + {"a"}, + createColumns({ + createColumn({}), + })); + + auto filter = std::make_shared(7, false); + filter->set(4, 1); + stream = VectorIndexTestUtils::wrapVectorStream(ctx, make_block_stream(), filter); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + {"a"}, + createColumns({ + createColumn({1}), + })); + + filter = std::make_shared(7, false); + filter->set(4, 1); + filter->set(0, 1); + stream = VectorIndexTestUtils::wrapVectorStream(ctx, make_block_stream(), filter); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + {"a"}, + createColumns({ + createColumn({7, 1}), + })); +} +CATCH + +TEST(VectorIndexInputStream, MultipleStreams) +try +{ + auto make_multi_stream = [] { + auto block1 = std::make_shared(Block{createColumns({ + createColumn({7, 4, 7, 0, 1, 2, 3}, "a"), + })}); + auto block2 = std::make_shared(Block{createColumns({ + createColumn({42, 45, 50, 37}, "a"), + })}); + return ConcatSkippableBlockInputStream::create( + {NopSkippableBlockInputStream::wrap(block1), NopSkippableBlockInputStream::wrap(block2)}, + {7, 4}, + nullptr); + }; + + // VectorIndexInputStream does not need this information, but ctx needs at least a correct vec column. 
+ auto ann = std::make_shared(); + ann->set_column_id(1); + auto ctx = VectorIndexStreamCtx::createForStableOnlyTests( + ann, + std::make_shared( + ColumnDefines{ColumnDefine{1, "vec", tests::typeFromString("Array(Float32)")}})); + + auto stream = VectorIndexInputStream::create(ctx, std::make_shared(11, true), make_multi_stream()); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + {"a"}, + createColumns({ + createColumn({7, 4, 7, 0, 1, 2, 3, 42, 45, 50, 37}), + })); + + stream = VectorIndexInputStream::create(ctx, std::make_shared(11, false), make_multi_stream()); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + {"a"}, + createColumns({ + createColumn({}), + })); + + auto filter = std::make_shared(11, false); + filter->set(4, 1); + stream = VectorIndexInputStream::create(ctx, filter, make_multi_stream()); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + {"a"}, + createColumns({ + createColumn({1}), + })); + + filter = std::make_shared(11, false); + filter->set(4, 1); + filter->set(0, 1); + stream = VectorIndexInputStream::create(ctx, filter, make_multi_stream()); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + {"a"}, + createColumns({ + createColumn({7, 1}), + })); + + filter = std::make_shared(11, false); + filter->set(9, 1); + stream = VectorIndexInputStream::create(ctx, filter, make_multi_stream()); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + {"a"}, + createColumns({ + createColumn({50}), + })); + + filter = std::make_shared(11, false); + filter->set(4, 1); + filter->set(9, 1); + stream = VectorIndexInputStream::create(ctx, filter, make_multi_stream()); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + {"a"}, + createColumns({ + createColumn({1, 50}), + })); +} +CATCH + class VectorIndexDMFileTest : public VectorIndexTestUtils , public DB::base::TiFlashStorageTestBasic @@ -233,22 +379,22 @@ try // Read with exact match { - auto ann_query_info = std::make_shared(); - ann_query_info->set_column_id(vec_cd.id); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - 
ann_query_info->set_top_k(1); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.5})); - DMFileBlockInputStreamBuilder builder(dbContext()); - auto stream = builder.setAnnQureyInfo(ann_query_info) - .setBitmapFilter(BitmapFilterView::createWithFilter(3, true)) + + auto vec_idx_ctx = VectorIndexStreamCtx::createForStableOnlyTests( + annQueryInfoTopK({.vec = {1.0, 2.0, 3.5}, .top_k = 1}), + std::make_shared(read_cols)); + auto stream = builder.setVecIndexQuery(vec_idx_ctx) .build( dm_file, read_cols, RowKeyRanges{RowKeyRange::newAll(false, 1)}, std::make_shared()); ASSERT_INPUTSTREAM_COLS_UR( - stream, + VectorIndexTestUtils::wrapVectorStream( // + vec_idx_ctx, + stream, + std::make_shared(3, true)), createColumnNames(), createColumnData({ createColumn({2}), @@ -258,22 +404,21 @@ try // Read with approximate match { - auto ann_query_info = std::make_shared(); - ann_query_info->set_column_id(vec_cd.id); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - ann_query_info->set_top_k(1); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.8})); - + auto vec_idx_ctx = VectorIndexStreamCtx::createForStableOnlyTests( + annQueryInfoTopK({.vec = {1.0, 2.0, 3.8}, .top_k = 1}), + std::make_shared(read_cols)); DMFileBlockInputStreamBuilder builder(dbContext()); - auto stream = builder.setAnnQureyInfo(ann_query_info) - .setBitmapFilter(BitmapFilterView::createWithFilter(3, true)) + auto stream = builder.setVecIndexQuery(vec_idx_ctx) .build( dm_file, read_cols, RowKeyRanges{RowKeyRange::newAll(false, 1)}, std::make_shared()); ASSERT_INPUTSTREAM_COLS_UR( - stream, + VectorIndexTestUtils::wrapVectorStream( // + vec_idx_ctx, + stream, + std::make_shared(3, true)), createColumnNames(), createColumnData({ createColumn({2}), @@ -283,22 +428,21 @@ try // Read multiple rows { - auto ann_query_info = std::make_shared(); - ann_query_info->set_column_id(vec_cd.id); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - 
ann_query_info->set_top_k(2); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.8})); - + auto vec_idx_ctx = VectorIndexStreamCtx::createForStableOnlyTests( + annQueryInfoTopK({.vec = {1.0, 2.0, 3.8}, .top_k = 2}), + std::make_shared(read_cols)); DMFileBlockInputStreamBuilder builder(dbContext()); - auto stream = builder.setAnnQureyInfo(ann_query_info) - .setBitmapFilter(BitmapFilterView::createWithFilter(3, true)) + auto stream = builder.setVecIndexQuery(vec_idx_ctx) .build( dm_file, read_cols, RowKeyRanges{RowKeyRange::newAll(false, 1)}, std::make_shared()); ASSERT_INPUTSTREAM_COLS_UR( - stream, + VectorIndexTestUtils::wrapVectorStream( // + vec_idx_ctx, + stream, + std::make_shared(3, true)), createColumnNames(), createColumnData({ createColumn({0, 2}), @@ -308,25 +452,24 @@ try // Read with MVCC filter { - auto ann_query_info = std::make_shared(); - ann_query_info->set_column_id(vec_cd.id); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - ann_query_info->set_top_k(1); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.8})); - auto bitmap_filter = std::make_shared(3, true); bitmap_filter->set(/* start */ 2, /* limit */ 1, false); + auto vec_idx_ctx = VectorIndexStreamCtx::createForStableOnlyTests( + annQueryInfoTopK({.vec = {1.0, 2.0, 3.8}, .top_k = 1}), + std::make_shared(read_cols)); DMFileBlockInputStreamBuilder builder(dbContext()); - auto stream = builder.setAnnQureyInfo(ann_query_info) - .setBitmapFilter(BitmapFilterView(bitmap_filter, 0, 3)) + auto stream = builder.setVecIndexQuery(vec_idx_ctx) .build( dm_file, read_cols, RowKeyRanges{RowKeyRange::newAll(false, 1)}, std::make_shared()); ASSERT_INPUTSTREAM_COLS_UR( - stream, + VectorIndexTestUtils::wrapVectorStream( // + vec_idx_ctx, + stream, + bitmap_filter), createColumnNames(), createColumnData({ createColumn({0}), @@ -336,22 +479,21 @@ try // Query Top K = 0: the pack should be filtered out { - auto ann_query_info = std::make_shared(); - 
ann_query_info->set_column_id(vec_cd.id); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - ann_query_info->set_top_k(0); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.8})); - + auto vec_idx_ctx = VectorIndexStreamCtx::createForStableOnlyTests( + annQueryInfoTopK({.vec = {1.0, 2.0, 3.8}, .top_k = 0}), + std::make_shared(read_cols)); DMFileBlockInputStreamBuilder builder(dbContext()); - auto stream = builder.setAnnQureyInfo(ann_query_info) - .setBitmapFilter(BitmapFilterView::createWithFilter(3, true)) + auto stream = builder.setVecIndexQuery(vec_idx_ctx) .build( dm_file, read_cols, RowKeyRanges{RowKeyRange::newAll(false, 1)}, std::make_shared()); ASSERT_INPUTSTREAM_COLS_UR( - stream, + VectorIndexTestUtils::wrapVectorStream( // + vec_idx_ctx, + stream, + std::make_shared(3, true)), createColumnNames(), createColumnData({ createColumn({}), @@ -361,22 +503,21 @@ try // Query Top K > rows { - auto ann_query_info = std::make_shared(); - ann_query_info->set_column_id(vec_cd.id); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - ann_query_info->set_top_k(10); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.8})); - + auto vec_idx_ctx = VectorIndexStreamCtx::createForStableOnlyTests( + annQueryInfoTopK({.vec = {1.0, 2.0, 3.8}, .top_k = 10}), + std::make_shared(read_cols)); DMFileBlockInputStreamBuilder builder(dbContext()); - auto stream = builder.setAnnQureyInfo(ann_query_info) - .setBitmapFilter(BitmapFilterView::createWithFilter(3, true)) + auto stream = builder.setVecIndexQuery(vec_idx_ctx) .build( dm_file, read_cols, RowKeyRanges{RowKeyRange::newAll(false, 1)}, std::make_shared()); ASSERT_INPUTSTREAM_COLS_UR( - stream, + VectorIndexTestUtils::wrapVectorStream( // + vec_idx_ctx, + stream, + std::make_shared(3, true)), createColumnNames(), createColumnData({ createColumn({0, 1, 2}), @@ -386,25 +527,24 @@ try // Illegal ANNQueryInfo: Ref Vector'dimension is different { - auto 
ann_query_info = std::make_shared(); - ann_query_info->set_column_id(vec_cd.id); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - ann_query_info->set_top_k(10); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0})); - + auto vec_idx_ctx = VectorIndexStreamCtx::createForStableOnlyTests( + annQueryInfoTopK({.vec = {1.0}, .top_k = 10}), + std::make_shared(read_cols)); DMFileBlockInputStreamBuilder builder(dbContext()); - auto stream = builder.setAnnQureyInfo(ann_query_info) - .setBitmapFilter(BitmapFilterView::createWithFilter(3, true)) + auto stream = builder.setVecIndexQuery(vec_idx_ctx) .build( dm_file, read_cols, RowKeyRanges{RowKeyRange::newAll(false, 1)}, std::make_shared()); - + auto stream2 = VectorIndexTestUtils::wrapVectorStream( // + vec_idx_ctx, + stream, + std::make_shared(3, true)); try { - stream->readPrefix(); - stream->read(); + stream2->readPrefix(); + stream2->read(); FAIL(); } catch (const DB::Exception & ex) @@ -418,53 +558,51 @@ try } } - // Illegal ANNQueryInfo: Referencing a non-existed column. This simply cause vector index not used. - // The query will not fail, because ANNQueryInfo is passed globally in the whole read path. + // Illegal ANNQueryInfo: Referencing a non-existed column (and the column is not in the read schema). + // This will throw exceptions. 
{ - auto ann_query_info = std::make_shared(); - ann_query_info->set_column_id(5); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - ann_query_info->set_top_k(1); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.8})); - - DMFileBlockInputStreamBuilder builder(dbContext()); - auto stream = builder.setAnnQureyInfo(ann_query_info) - .setBitmapFilter(BitmapFilterView::createWithFilter(3, true)) - .build( - dm_file, - read_cols, - RowKeyRanges{RowKeyRange::newAll(false, 1)}, - std::make_shared()); - ASSERT_INPUTSTREAM_COLS_UR( - stream, - createColumnNames(), - createColumnData({ - createColumn({0, 1, 2}), - createVecFloat32Column({{1.0, 2.0, 3.0}, {0.0, 0.0, 0.0}, {1.0, 2.0, 3.5}}), - })); + try + { + auto vec_idx_ctx = VectorIndexStreamCtx::createForStableOnlyTests( + annQueryInfoTopK({.vec = {1.0, 2.0, 3.8}, .top_k = 1, .column_id = 5}), + std::make_shared(read_cols)); + FAIL(); + } + catch (const DB::Exception & ex) + { + EXPECT_TRUE(ex.message().find("Check vec_cd.has_value() failed") != std::string::npos) << ex.message(); + } + catch (...) + { + FAIL(); + } } // Illegal ANNQueryInfo: Different distance metric. 
{ - auto ann_query_info = std::make_shared(); - ann_query_info->set_column_id(vec_cd.id); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::COSINE); - ann_query_info->set_top_k(1); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.8})); - + auto vec_idx_ctx = VectorIndexStreamCtx::createForStableOnlyTests( + annQueryInfoTopK({ + .vec = {1.0, 2.0, 3.8}, + .top_k = 1, + .distance_metric = tipb::VectorDistanceMetric::COSINE, + }), + std::make_shared(read_cols)); DMFileBlockInputStreamBuilder builder(dbContext()); - auto stream = builder.setAnnQureyInfo(ann_query_info) - .setBitmapFilter(BitmapFilterView::createWithFilter(3, true)) + auto stream = builder.setVecIndexQuery(vec_idx_ctx) .build( dm_file, read_cols, RowKeyRanges{RowKeyRange::newAll(false, 1)}, std::make_shared()); + auto stream2 = VectorIndexTestUtils::wrapVectorStream( // + vec_idx_ctx, + stream, + std::make_shared(3, true)); try { - stream->readPrefix(); - stream->read(); + stream2->readPrefix(); + stream2->read(); FAIL(); } catch (const DB::Exception & ex) @@ -483,23 +621,27 @@ try // Illegal ANNQueryInfo: The column exists but is not a vector column. // Currently the query is fine and ANNQueryInfo is discarded, because we discovered that this column // does not have index at all. + if (!test_only_vec_column) { - auto ann_query_info = std::make_shared(); - ann_query_info->set_column_id(MutSup::extra_handle_id); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - ann_query_info->set_top_k(1); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.8})); + // Note, in test_only_vec_column mode, column_id becomes a not-found column in read_cols + // so that an exception will be raised instead. This case is already checked before. + // So here we only check with test_only_vec_column==false. 
+ auto vec_idx_ctx = VectorIndexStreamCtx::createForStableOnlyTests( + annQueryInfoTopK({.vec = {1.0, 2.0, 3.8}, .top_k = 1, .column_id = MutSup::extra_handle_id}), + std::make_shared(read_cols)); DMFileBlockInputStreamBuilder builder(dbContext()); - auto stream = builder.setAnnQureyInfo(ann_query_info) - .setBitmapFilter(BitmapFilterView::createWithFilter(3, true)) + auto stream = builder.setVecIndexQuery(vec_idx_ctx) .build( dm_file, read_cols, RowKeyRanges{RowKeyRange::newAll(false, 1)}, std::make_shared()); ASSERT_INPUTSTREAM_COLS_UR( - stream, + VectorIndexTestUtils::wrapVectorStream( // + vec_idx_ctx, + stream, + std::make_shared(3, true)), createColumnNames(), createColumnData({ createColumn({0, 1, 2}), @@ -535,38 +677,32 @@ try dm_file = restoreDMFile(); auto index_infos = std::make_shared(LocalIndexInfos{ // index with index_id == 3 - LocalIndexInfo{ - .kind = TiDB::ColumnarIndexKind::Vector, - .index_id = 3, - .column_id = vec_column_id, - .def_vector_index = std::make_shared(TiDB::VectorIndexDefinition{ + LocalIndexInfo( + 3, + vec_column_id, + std::make_shared(TiDB::VectorIndexDefinition{ .kind = tipb::VectorIndexKind::HNSW, .dimension = 3, .distance_metric = tipb::VectorDistanceMetric::L2, - }), - }, + })), // index with index_id == 4 - LocalIndexInfo{ - .kind = TiDB::ColumnarIndexKind::Vector, - .index_id = 4, - .column_id = vec_column_id, - .def_vector_index = std::make_shared(TiDB::VectorIndexDefinition{ + LocalIndexInfo( + 4, + vec_column_id, + std::make_shared(TiDB::VectorIndexDefinition{ .kind = tipb::VectorIndexKind::HNSW, .dimension = 3, .distance_metric = tipb::VectorDistanceMetric::COSINE, - }), - }, + })), // index with index_id == EmptyIndexID, column_id = vec_column_id - LocalIndexInfo{ - .kind = TiDB::ColumnarIndexKind::Vector, - .index_id = EmptyIndexID, - .column_id = vec_column_id, - .def_vector_index = std::make_shared(TiDB::VectorIndexDefinition{ + LocalIndexInfo( + EmptyIndexID, + vec_column_id, + 
std::make_shared(TiDB::VectorIndexDefinition{ .kind = tipb::VectorIndexKind::HNSW, .dimension = 3, .distance_metric = tipb::VectorDistanceMetric::L2, - }), - }, + })), }); dm_file = buildMultiIndex(index_infos); @@ -581,23 +717,21 @@ try // Read with approximate match { - auto ann_query_info = std::make_shared(); - ann_query_info->set_column_id(vec_cd.id); - ann_query_info->set_index_id(3); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - ann_query_info->set_top_k(1); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.8})); - + auto vec_idx_ctx = VectorIndexStreamCtx::createForStableOnlyTests( + annQueryInfoTopK({.vec = {1.0, 2.0, 3.8}, .top_k = 1, .index_id = 3}), + std::make_shared(read_cols)); DMFileBlockInputStreamBuilder builder(dbContext()); - auto stream = builder.setAnnQureyInfo(ann_query_info) - .setBitmapFilter(BitmapFilterView::createWithFilter(3, true)) + auto stream = builder.setVecIndexQuery(vec_idx_ctx) .build( dm_file, read_cols, RowKeyRanges{RowKeyRange::newAll(false, 1)}, std::make_shared()); ASSERT_INPUTSTREAM_COLS_UR( - stream, + VectorIndexTestUtils::wrapVectorStream( // + vec_idx_ctx, + stream, + std::make_shared(3, true)), createColumnNames(), createColumnData({ createColumn({2}), @@ -607,23 +741,21 @@ try // Read multiple rows { - auto ann_query_info = std::make_shared(); - ann_query_info->set_column_id(vec_cd.id); - ann_query_info->set_index_id(3); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - ann_query_info->set_top_k(2); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.8})); - + auto vec_idx_ctx = VectorIndexStreamCtx::createForStableOnlyTests( + annQueryInfoTopK({.vec = {1.0, 2.0, 3.8}, .top_k = 2, .index_id = 3}), + std::make_shared(read_cols)); DMFileBlockInputStreamBuilder builder(dbContext()); - auto stream = builder.setAnnQureyInfo(ann_query_info) - .setBitmapFilter(BitmapFilterView::createWithFilter(3, true)) + auto stream = 
builder.setVecIndexQuery(vec_idx_ctx) .build( dm_file, read_cols, RowKeyRanges{RowKeyRange::newAll(false, 1)}, std::make_shared()); ASSERT_INPUTSTREAM_COLS_UR( - stream, + VectorIndexTestUtils::wrapVectorStream( // + vec_idx_ctx, + stream, + std::make_shared(3, true)), createColumnNames(), createColumnData({ createColumn({0, 2}), @@ -633,26 +765,24 @@ try // Read with MVCC filter { - auto ann_query_info = std::make_shared(); - ann_query_info->set_column_id(vec_cd.id); - ann_query_info->set_index_id(3); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - ann_query_info->set_top_k(1); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.8})); - auto bitmap_filter = std::make_shared(3, true); bitmap_filter->set(/* start */ 2, /* limit */ 1, false); + auto vec_idx_ctx = VectorIndexStreamCtx::createForStableOnlyTests( + annQueryInfoTopK({.vec = {1.0, 2.0, 3.8}, .top_k = 1, .index_id = 3}), + std::make_shared(read_cols)); DMFileBlockInputStreamBuilder builder(dbContext()); - auto stream = builder.setAnnQureyInfo(ann_query_info) - .setBitmapFilter(BitmapFilterView(bitmap_filter, 0, 3)) + auto stream = builder.setVecIndexQuery(vec_idx_ctx) .build( dm_file, read_cols, RowKeyRanges{RowKeyRange::newAll(false, 1)}, std::make_shared()); ASSERT_INPUTSTREAM_COLS_UR( - stream, + VectorIndexTestUtils::wrapVectorStream( // + vec_idx_ctx, + stream, + bitmap_filter), createColumnNames(), createColumnData({ createColumn({0}), @@ -666,23 +796,26 @@ try // Read with approximate match { - auto ann_query_info = std::make_shared(); - ann_query_info->set_column_id(vec_cd.id); - ann_query_info->set_index_id(4); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::COSINE); - ann_query_info->set_top_k(1); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.8})); - + auto vec_idx_ctx = VectorIndexStreamCtx::createForStableOnlyTests( + annQueryInfoTopK({ + .vec = {1.0, 2.0, 3.8}, + .top_k = 1, + .index_id = 4, + .distance_metric = 
tipb::VectorDistanceMetric::COSINE, + }), + std::make_shared(read_cols)); DMFileBlockInputStreamBuilder builder(dbContext()); - auto stream = builder.setAnnQureyInfo(ann_query_info) - .setBitmapFilter(BitmapFilterView::createWithFilter(3, true)) + auto stream = builder.setVecIndexQuery(vec_idx_ctx) .build( dm_file, read_cols, RowKeyRanges{RowKeyRange::newAll(false, 1)}, std::make_shared()); ASSERT_INPUTSTREAM_COLS_UR( - stream, + VectorIndexTestUtils::wrapVectorStream( // + vec_idx_ctx, + stream, + std::make_shared(3, true)), createColumnNames(), createColumnData({ createColumn({2}), @@ -692,23 +825,26 @@ try // Read multiple rows { - auto ann_query_info = std::make_shared(); - ann_query_info->set_column_id(vec_cd.id); - ann_query_info->set_index_id(4); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::COSINE); - ann_query_info->set_top_k(2); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.8})); - + auto vec_idx_ctx = VectorIndexStreamCtx::createForStableOnlyTests( + annQueryInfoTopK({ + .vec = {1.0, 2.0, 3.8}, + .top_k = 2, + .index_id = 4, + .distance_metric = tipb::VectorDistanceMetric::COSINE, + }), + std::make_shared(read_cols)); DMFileBlockInputStreamBuilder builder(dbContext()); - auto stream = builder.setAnnQureyInfo(ann_query_info) - .setBitmapFilter(BitmapFilterView::createWithFilter(3, true)) + auto stream = builder.setVecIndexQuery(vec_idx_ctx) .build( dm_file, read_cols, RowKeyRanges{RowKeyRange::newAll(false, 1)}, std::make_shared()); ASSERT_INPUTSTREAM_COLS_UR( - stream, + VectorIndexTestUtils::wrapVectorStream( // + vec_idx_ctx, + stream, + std::make_shared(3, true)), createColumnNames(), createColumnData({ createColumn({0, 2}), @@ -718,26 +854,29 @@ try // Read with MVCC filter { - auto ann_query_info = std::make_shared(); - ann_query_info->set_column_id(vec_cd.id); - ann_query_info->set_index_id(4); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::COSINE); - ann_query_info->set_top_k(1); - 
ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.8})); - auto bitmap_filter = std::make_shared(3, true); bitmap_filter->set(/* start */ 2, /* limit */ 1, false); + auto vec_idx_ctx = VectorIndexStreamCtx::createForStableOnlyTests( + annQueryInfoTopK({ + .vec = {1.0, 2.0, 3.8}, + .top_k = 1, + .index_id = 4, + .distance_metric = tipb::VectorDistanceMetric::COSINE, + }), + std::make_shared(read_cols)); DMFileBlockInputStreamBuilder builder(dbContext()); - auto stream = builder.setAnnQureyInfo(ann_query_info) - .setBitmapFilter(BitmapFilterView(bitmap_filter, 0, 3)) + auto stream = builder.setVecIndexQuery(vec_idx_ctx) .build( dm_file, read_cols, RowKeyRanges{RowKeyRange::newAll(false, 1)}, std::make_shared()); ASSERT_INPUTSTREAM_COLS_UR( - stream, + VectorIndexTestUtils::wrapVectorStream( // + vec_idx_ctx, + stream, + bitmap_filter), createColumnNames(), createColumnData({ createColumn({0}), @@ -752,22 +891,21 @@ try // Read with approximate match { - auto ann_query_info = std::make_shared(); - ann_query_info->set_column_id(vec_cd.id); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - ann_query_info->set_top_k(1); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.8})); - + auto vec_idx_ctx = VectorIndexStreamCtx::createForStableOnlyTests( + annQueryInfoTopK({.vec = {1.0, 2.0, 3.8}, .top_k = 1}), + std::make_shared(read_cols)); DMFileBlockInputStreamBuilder builder(dbContext()); - auto stream = builder.setAnnQureyInfo(ann_query_info) - .setBitmapFilter(BitmapFilterView::createWithFilter(3, true)) + auto stream = builder.setVecIndexQuery(vec_idx_ctx) .build( dm_file, read_cols, RowKeyRanges{RowKeyRange::newAll(false, 1)}, std::make_shared()); ASSERT_INPUTSTREAM_COLS_UR( - stream, + VectorIndexTestUtils::wrapVectorStream( // + vec_idx_ctx, + stream, + std::make_shared(3, true)), createColumnNames(), createColumnData({ createColumn({2}), @@ -777,22 +915,21 @@ try // Read multiple rows { - auto ann_query_info = 
std::make_shared(); - ann_query_info->set_column_id(vec_cd.id); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - ann_query_info->set_top_k(2); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.8})); - + auto vec_idx_ctx = VectorIndexStreamCtx::createForStableOnlyTests( + annQueryInfoTopK({.vec = {1.0, 2.0, 3.8}, .top_k = 2}), + std::make_shared(read_cols)); DMFileBlockInputStreamBuilder builder(dbContext()); - auto stream = builder.setAnnQureyInfo(ann_query_info) - .setBitmapFilter(BitmapFilterView::createWithFilter(3, true)) + auto stream = builder.setVecIndexQuery(vec_idx_ctx) .build( dm_file, read_cols, RowKeyRanges{RowKeyRange::newAll(false, 1)}, std::make_shared()); ASSERT_INPUTSTREAM_COLS_UR( - stream, + VectorIndexTestUtils::wrapVectorStream( // + vec_idx_ctx, + stream, + std::make_shared(3, true)), createColumnNames(), createColumnData({ createColumn({0, 2}), @@ -802,25 +939,24 @@ try // Read with MVCC filter { - auto ann_query_info = std::make_shared(); - ann_query_info->set_column_id(vec_cd.id); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - ann_query_info->set_top_k(1); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.8})); - auto bitmap_filter = std::make_shared(3, true); bitmap_filter->set(/* start */ 2, /* limit */ 1, false); + auto vec_idx_ctx = VectorIndexStreamCtx::createForStableOnlyTests( + annQueryInfoTopK({.vec = {1.0, 2.0, 3.8}, .top_k = 1}), + std::make_shared(read_cols)); DMFileBlockInputStreamBuilder builder(dbContext()); - auto stream = builder.setAnnQureyInfo(ann_query_info) - .setBitmapFilter(BitmapFilterView(bitmap_filter, 0, 3)) + auto stream = builder.setVecIndexQuery(vec_idx_ctx) .build( dm_file, read_cols, RowKeyRanges{RowKeyRange::newAll(false, 1)}, std::make_shared()); ASSERT_INPUTSTREAM_COLS_UR( - stream, + VectorIndexTestUtils::wrapVectorStream( // + vec_idx_ctx, + stream, + bitmap_filter), createColumnNames(), createColumnData({ 
createColumn({0}), @@ -869,15 +1005,11 @@ try dm_file = buildIndex(*vector_index); { - auto ann_query_info = std::make_shared(); - ann_query_info->set_column_id(vec_cd.id); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - ann_query_info->set_top_k(4); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.5})); - + auto vec_idx_ctx = VectorIndexStreamCtx::createForStableOnlyTests( + annQueryInfoTopK({.vec = {1.0, 2.0, 3.5}, .top_k = 4}), + std::make_shared(read_cols)); DMFileBlockInputStreamBuilder builder(dbContext()); - auto stream = builder.setAnnQureyInfo(ann_query_info) - .setBitmapFilter(BitmapFilterView::createWithFilter(5, true)) + auto stream = builder.setVecIndexQuery(vec_idx_ctx) .build( dm_file, read_cols, @@ -885,7 +1017,10 @@ try std::make_shared()); ASSERT_INPUTSTREAM_COLS_UR( - stream, + VectorIndexTestUtils::wrapVectorStream( // + vec_idx_ctx, + stream, + std::make_shared(5, true)), createColumnNames(), createColumnData({ createColumn({0, 1, 3, 4}), @@ -937,22 +1072,21 @@ try // Pack #0 is filtered out according to VecIndex { - auto ann_query_info = std::make_shared(); - ann_query_info->set_column_id(vec_cd.id); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - ann_query_info->set_top_k(1); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({5.0, 5.0, 5.5})); - + auto vec_idx_ctx = VectorIndexStreamCtx::createForStableOnlyTests( + annQueryInfoTopK({.vec = {5.0, 5.0, 5.5}, .top_k = 1}), + std::make_shared(read_cols)); DMFileBlockInputStreamBuilder builder(dbContext()); - auto stream = builder.setAnnQureyInfo(ann_query_info) - .setBitmapFilter(BitmapFilterView::createWithFilter(6, true)) + auto stream = builder.setVecIndexQuery(vec_idx_ctx) .build( dm_file, read_cols, RowKeyRanges{RowKeyRange::newAll(false, 1)}, std::make_shared()); ASSERT_INPUTSTREAM_COLS_UR( - stream, + VectorIndexTestUtils::wrapVectorStream( // + vec_idx_ctx, + stream, + std::make_shared(6, true)), createColumnNames(), 
createColumnData({ createColumn({3}), @@ -962,22 +1096,21 @@ try // Pack #1 is filtered out according to VecIndex { - auto ann_query_info = std::make_shared(); - ann_query_info->set_column_id(vec_cd.id); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - ann_query_info->set_top_k(1); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.0})); - + auto vec_idx_ctx = VectorIndexStreamCtx::createForStableOnlyTests( + annQueryInfoTopK({.vec = {1.0, 2.0, 3.0}, .top_k = 1}), + std::make_shared(read_cols)); DMFileBlockInputStreamBuilder builder(dbContext()); - auto stream = builder.setAnnQureyInfo(ann_query_info) - .setBitmapFilter(BitmapFilterView::createWithFilter(6, true)) + auto stream = builder.setVecIndexQuery(vec_idx_ctx) .build( dm_file, read_cols, RowKeyRanges{RowKeyRange::newAll(false, 1)}, std::make_shared()); ASSERT_INPUTSTREAM_COLS_UR( - stream, + VectorIndexTestUtils::wrapVectorStream( // + vec_idx_ctx, + stream, + std::make_shared(6, true)), createColumnNames(), createColumnData({ createColumn({0}), @@ -987,22 +1120,21 @@ try // Both packs are reserved { - auto ann_query_info = std::make_shared(); - ann_query_info->set_column_id(vec_cd.id); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - ann_query_info->set_top_k(2); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({0.0, 0.0, 0.0})); - + auto vec_idx_ctx = VectorIndexStreamCtx::createForStableOnlyTests( + annQueryInfoTopK({.vec = {0.0, 0.0, 0.0}, .top_k = 2}), + std::make_shared(read_cols)); DMFileBlockInputStreamBuilder builder(dbContext()); - auto stream = builder.setAnnQureyInfo(ann_query_info) - .setBitmapFilter(BitmapFilterView::createWithFilter(6, true)) + auto stream = builder.setVecIndexQuery(vec_idx_ctx) .build( dm_file, read_cols, RowKeyRanges{RowKeyRange::newAll(false, 1)}, std::make_shared()); ASSERT_INPUTSTREAM_COLS_UR( - stream, + VectorIndexTestUtils::wrapVectorStream( // + vec_idx_ctx, + stream, + std::make_shared(6, true)), 
createColumnNames(), createColumnData({ createColumn({1, 5}), @@ -1012,25 +1144,24 @@ try // Pack Filter + MVCC (the matching row #5 is marked as filtered out by MVCC) { - auto ann_query_info = std::make_shared(); - ann_query_info->set_column_id(vec_cd.id); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - ann_query_info->set_top_k(2); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({0.0, 0.0, 0.0})); - auto bitmap_filter = std::make_shared(6, true); bitmap_filter->set(/* start */ 5, /* limit */ 1, false); + auto vec_idx_ctx = VectorIndexStreamCtx::createForStableOnlyTests( + annQueryInfoTopK({.vec = {0.0, 0.0, 0.0}, .top_k = 2}), + std::make_shared(read_cols)); DMFileBlockInputStreamBuilder builder(dbContext()); - auto stream = builder.setAnnQureyInfo(ann_query_info) - .setBitmapFilter(BitmapFilterView(bitmap_filter, 0, 6)) + auto stream = builder.setVecIndexQuery(vec_idx_ctx) .build( dm_file, read_cols, RowKeyRanges{RowKeyRange::newAll(false, 1)}, std::make_shared()); ASSERT_INPUTSTREAM_COLS_UR( - stream, + VectorIndexTestUtils::wrapVectorStream( // + vec_idx_ctx, + stream, + bitmap_filter), createColumnNames(), createColumnData({ createColumn({0, 1}), @@ -1080,24 +1211,23 @@ try // Pack Filter using RowKeyRange { - auto ann_query_info = std::make_shared(); - ann_query_info->set_column_id(vec_cd.id); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - ann_query_info->set_top_k(1); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({8.0})); - // This row key range will cause pack#0 and pack#1 reserved, and pack#2 filtered out. auto row_key_ranges = RowKeyRanges{RowKeyRange::fromHandleRange(HandleRange(0, 5))}; auto bitmap_filter = std::make_shared(9, false); bitmap_filter->set(0, 6); // 0~6 rows are valid, 6~9 rows are invalid due to pack filter. 
+ auto vec_idx_ctx = VectorIndexStreamCtx::createForStableOnlyTests( + annQueryInfoTopK({.vec = {8.0}, .top_k = 1}), + std::make_shared(read_cols)); DMFileBlockInputStreamBuilder builder(dbContext()); - auto stream = builder.setAnnQureyInfo(ann_query_info) - .setBitmapFilter(BitmapFilterView(bitmap_filter, 0, 9)) + auto stream = builder.setVecIndexQuery(vec_idx_ctx) .build(dm_file, read_cols, row_key_ranges, std::make_shared()); ASSERT_INPUTSTREAM_COLS_UR( - stream, + VectorIndexTestUtils::wrapVectorStream( // + vec_idx_ctx, + stream, + bitmap_filter), createColumnNames(), createColumnData({ createColumn({5}), @@ -1105,13 +1235,17 @@ try })); // TopK=4 - ann_query_info->set_top_k(4); + vec_idx_ctx = VectorIndexStreamCtx::createForStableOnlyTests( + annQueryInfoTopK({.vec = {8.0}, .top_k = 4}), + std::make_shared(read_cols)); builder = DMFileBlockInputStreamBuilder(dbContext()); - stream = builder.setAnnQureyInfo(ann_query_info) - .setBitmapFilter(BitmapFilterView(bitmap_filter, 0, 9)) + stream = builder.setVecIndexQuery(vec_idx_ctx) .build(dm_file, read_cols, row_key_ranges, std::make_shared()); ASSERT_INPUTSTREAM_COLS_UR( - stream, + VectorIndexTestUtils::wrapVectorStream( // + vec_idx_ctx, + stream, + bitmap_filter), createColumnNames(), createColumnData({ createColumn({2, 3, 4, 5}), @@ -1121,12 +1255,6 @@ try // Pack Filter + Bitmap Filter { - auto ann_query_info = std::make_shared(); - ann_query_info->set_column_id(vec_cd.id); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - ann_query_info->set_top_k(3); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32({8.0})); - // This row key range will cause pack#0 and pack#1 reserved, and pack#2 filtered out. 
auto row_key_ranges = RowKeyRanges{RowKeyRange::fromHandleRange(HandleRange(0, 5))}; @@ -1135,12 +1263,17 @@ try bitmap_filter->set(0, 2); bitmap_filter->set(3, 2); + auto vec_idx_ctx = VectorIndexStreamCtx::createForStableOnlyTests( + annQueryInfoTopK({.vec = {8.0}, .top_k = 3}), + std::make_shared(read_cols)); DMFileBlockInputStreamBuilder builder(dbContext()); - auto stream = builder.setAnnQureyInfo(ann_query_info) - .setBitmapFilter(BitmapFilterView(bitmap_filter, 0, 9)) + auto stream = builder.setVecIndexQuery(vec_idx_ctx) .build(dm_file, read_cols, row_key_ranges, std::make_shared()); ASSERT_INPUTSTREAM_COLS_UR( - stream, + VectorIndexTestUtils::wrapVectorStream( // + vec_idx_ctx, + stream, + bitmap_filter), createColumnNames(), createColumnData({ createColumn({1, 3, 4}), @@ -1171,12 +1304,7 @@ class VectorIndexSegmentTestBase UInt32 top_k, const std::vector & ref_vec) { - auto ann_query_info = std::make_shared(); - ann_query_info->set_column_id(vec_column_id); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - ann_query_info->set_top_k(top_k); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32(ref_vec)); - return read(segment_id, begin, end, columns_to_read, ann_query_info); + return read(segment_id, begin, end, columns_to_read, annQueryInfoTopK({.vec = ref_vec, .top_k = top_k})); } BlockInputStreamPtr annQuery( @@ -1196,7 +1324,7 @@ class VectorIndexSegmentTestBase ColumnDefines columns_to_read, ANNQueryInfoPtr ann_query) { - auto range = buildRowKeyRange(begin, end); + auto range = buildRowKeyRange(begin, end, /*is_common_handle*/ false); auto [segment, snapshot] = getSegmentForRead(segment_id); // load DMilePackFilterResult for each DMFile DMFilePackFilterResults pack_filter_results; @@ -1228,9 +1356,15 @@ class VectorIndexSegmentTestBase ColumnDefine cdPK() { return getExtraHandleColumnDefine(options.is_common_handle); } protected: - Block prepareWriteBlockImpl(Int64 start_key, Int64 end_key, bool is_deleted) override - { - auto 
block = SegmentTestBasic::prepareWriteBlockImpl(start_key, end_key, is_deleted); + Block prepareWriteBlockImpl( + Int64 start_key, + Int64 end_key, + bool is_deleted, + bool including_right_boundary, + std::optional ts) override + { + auto block + = SegmentTestBasic::prepareWriteBlockImpl(start_key, end_key, is_deleted, including_right_boundary, ts); block.insert(colVecFloat32(fmt::format("[{}, {})", start_key, end_key), vec_column_name, vec_column_id)); return block; } @@ -1445,20 +1579,21 @@ try ensureSegmentStableLocalIndex(DELTA_MERGE_FIRST_SEGMENT_ID, indexInfo()); writeSegment(DELTA_MERGE_FIRST_SEGMENT_ID, 10, /* at */ 20); + writeSegmentWithDeleteRange(DELTA_MERGE_FIRST_SEGMENT_ID, /* begin */ 25, /* end */ 27, false, false); // ANNQuery will be only effective to Stable layer. All delta data will be returned. auto stream = annQuery(DELTA_MERGE_FIRST_SEGMENT_ID, createQueryColumns(), 1, {100.0}); - assertStreamOut(stream, "[4, 5)|[20, 30)"); + assertStreamOut(stream, "[4, 5)|[20, 25)|[27, 30)"); stream = annQuery(DELTA_MERGE_FIRST_SEGMENT_ID, createQueryColumns(), 2, {10.0}); - assertStreamOut(stream, "[3, 5)|[20, 30)"); + assertStreamOut(stream, "[3, 5)|[20, 25)|[27, 30)"); stream = annQuery(DELTA_MERGE_FIRST_SEGMENT_ID, createQueryColumns(), 5, {10.0}); - assertStreamOut(stream, "[0, 5)|[20, 30)"); + assertStreamOut(stream, "[0, 5)|[20, 25)|[27, 30)"); stream = annQuery(DELTA_MERGE_FIRST_SEGMENT_ID, createQueryColumns(), 10, {10.0}); - assertStreamOut(stream, "[0, 5)|[20, 30)"); + assertStreamOut(stream, "[0, 5)|[20, 25)|[27, 30)"); } CATCH @@ -1666,9 +1801,19 @@ class VectorIndexSegmentExtraColumnTest return ColumnDefine(extra_column_id, extra_column_name, tests::typeFromString("Int64")); } - Block prepareWriteBlockImpl(Int64 start_key, Int64 end_key, bool is_deleted) override + Block prepareWriteBlockImpl( + Int64 start_key, + Int64 end_key, + bool is_deleted, + bool including_right_boundary, + std::optional ts) override { - auto block = 
VectorIndexSegmentTestBase::prepareWriteBlockImpl(start_key, end_key, is_deleted); + auto block = VectorIndexSegmentTestBase::prepareWriteBlockImpl( + start_key, + end_key, + is_deleted, + including_right_boundary, + ts); block.insert( colInt64(fmt::format("[{}, {})", start_key + 1000, end_key + 1000), extra_column_name, extra_column_id)); return block; @@ -1817,7 +1962,7 @@ class VectorIndexSegmentOnS3Test static ColumnDefine cdPK() { return getExtraHandleColumnDefine(false); } - BlockInputStreamPtr createComputeNodeStream( + std::pair createComputeNodeStream( const SegmentPtr & write_node_segment, const ColumnDefines & columns_to_read, const PushDownExecutorPtr & filter, @@ -1861,7 +2006,7 @@ class VectorIndexSegmentOnS3Test std::numeric_limits::max(), DEFAULT_BLOCK_SIZE); - return stream; + return {stream, read_dm_context}; } static void removeAllFileCache() @@ -1897,38 +2042,32 @@ class VectorIndexSegmentOnS3Test auto dm_files = segment->getStable()->getDMFiles(); auto index_infos = std::make_shared(LocalIndexInfos{ // index with index_id == 3 - LocalIndexInfo{ - .kind = TiDB::ColumnarIndexKind::Vector, - .index_id = 3, - .column_id = vec_column_id, - .def_vector_index = std::make_shared(TiDB::VectorIndexDefinition{ + LocalIndexInfo( + 3, + vec_column_id, + std::make_shared(TiDB::VectorIndexDefinition{ .kind = tipb::VectorIndexKind::HNSW, .dimension = 1, .distance_metric = tipb::VectorDistanceMetric::L2, - }), - }, + })), // index with index_id == 4 - LocalIndexInfo{ - .kind = TiDB::ColumnarIndexKind::Vector, - .index_id = 4, - .column_id = vec_column_id, - .def_vector_index = std::make_shared(TiDB::VectorIndexDefinition{ + LocalIndexInfo( + 4, + vec_column_id, + std::make_shared(TiDB::VectorIndexDefinition{ .kind = tipb::VectorIndexKind::HNSW, .dimension = 1, .distance_metric = tipb::VectorDistanceMetric::COSINE, - }), - }, + })), // index with index_id == EmptyIndexID, column_id = vec_column_id - LocalIndexInfo{ - .kind = TiDB::ColumnarIndexKind::Vector, - 
.index_id = EmptyIndexID, - .column_id = vec_column_id, - .def_vector_index = std::make_shared(TiDB::VectorIndexDefinition{ + LocalIndexInfo( + EmptyIndexID, + vec_column_id, + std::make_shared(TiDB::VectorIndexDefinition{ .kind = tipb::VectorIndexKind::HNSW, .dimension = 1, .distance_metric = tipb::VectorDistanceMetric::L2, - }), - }, + })), }); auto build_info = DMFileLocalIndexWriter::getLocalIndexBuildInfo(index_infos, dm_files); @@ -1952,30 +2091,26 @@ class VectorIndexSegmentOnS3Test return new_segment; } - BlockInputStreamPtr computeNodeTableScan() + std::pair computeNodeTableScan() { return createComputeNodeStream(wn_segment, {cdPK(), cdVec()}, nullptr); } - BlockInputStreamPtr computeNodeANNQuery( + std::pair computeNodeANNQuery( const std::vector ref_vec, IndexID index_id, UInt32 top_k = 1, const ScanContextPtr & read_scan_context = nullptr) { - auto ann_query_info = std::make_shared(); - ann_query_info->set_index_id(index_id); - ann_query_info->set_column_id(vec_column_id); - ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); - ann_query_info->set_top_k(top_k); - ann_query_info->set_ref_vec_f32(encodeVectorFloat32(ref_vec)); - - auto stream = createComputeNodeStream( + return createComputeNodeStream( wn_segment, {cdPK(), cdVec()}, - std::make_shared(ann_query_info), + std::make_shared(annQueryInfoTopK({ + .vec = ref_vec, + .top_k = top_k, + .index_id = index_id, + })), read_scan_context); - return stream; } protected: @@ -2031,7 +2166,7 @@ try prepareWriteNodeStable(); FileCache::shutdown(); - auto stream = computeNodeANNQuery({5.0}, EmptyIndexID); + auto [stream, rn_dm_ctx] = computeNodeANNQuery({5.0}, EmptyIndexID); try { @@ -2059,7 +2194,7 @@ try ASSERT_EQ(0, file_cache->getAll().size()); } { - auto stream = computeNodeTableScan(); + auto [stream, rn_dm_ctx] = computeNodeTableScan(); ASSERT_INPUTSTREAM_COLS_UR( stream, Strings({DMTestEnv::pk_name, vec_column_name}), @@ -2086,7 +2221,7 @@ try } { auto scan_context = 
std::make_shared(); - auto stream = computeNodeANNQuery({5.0}, EmptyIndexID, 1, scan_context); + auto [stream, rn_dm_ctx] = computeNodeANNQuery({5.0}, EmptyIndexID, 1, scan_context); ASSERT_INPUTSTREAM_COLS_UR( stream, Strings({DMTestEnv::pk_name, vec_column_name}), @@ -2108,7 +2243,7 @@ try // Read again, we should be reading from memory cache. auto scan_context = std::make_shared(); - auto stream = computeNodeANNQuery({5.0}, EmptyIndexID, 1, scan_context); + auto [stream, rn_dm_ctx] = computeNodeANNQuery({5.0}, EmptyIndexID, 1, scan_context); ASSERT_INPUTSTREAM_COLS_UR( stream, Strings({DMTestEnv::pk_name, vec_column_name}), @@ -2137,7 +2272,7 @@ try IndexID query_index_id = EmptyIndexID; { auto scan_context = std::make_shared(); - auto stream = computeNodeANNQuery({5.0}, query_index_id, 1, scan_context); + auto [stream, rn_dm_ctx] = computeNodeANNQuery({5.0}, query_index_id, 1, scan_context); ASSERT_INPUTSTREAM_COLS_UR( stream, Strings({DMTestEnv::pk_name, vec_column_name}), @@ -2159,7 +2294,7 @@ try // Read again, we should be reading from memory cache. auto scan_context = std::make_shared(); - auto stream = computeNodeANNQuery({5.0}, query_index_id, 1, scan_context); + auto [stream, rn_dm_ctx] = computeNodeANNQuery({5.0}, query_index_id, 1, scan_context); ASSERT_INPUTSTREAM_COLS_UR( stream, Strings({DMTestEnv::pk_name, vec_column_name}), @@ -2178,7 +2313,7 @@ try IndexID query_index_id = 3; { auto scan_context = std::make_shared(); - auto stream = computeNodeANNQuery({5.0}, query_index_id, 1, scan_context); + auto [stream, rn_dm_ctx] = computeNodeANNQuery({5.0}, query_index_id, 1, scan_context); ASSERT_INPUTSTREAM_COLS_UR( stream, Strings({DMTestEnv::pk_name, vec_column_name}), @@ -2200,7 +2335,7 @@ try // Read again, we should be reading from memory cache. 
auto scan_context = std::make_shared(); - auto stream = computeNodeANNQuery({5.0}, query_index_id, 1, scan_context); + auto [stream, rn_dm_ctx] = computeNodeANNQuery({5.0}, query_index_id, 1, scan_context); ASSERT_INPUTSTREAM_COLS_UR( stream, Strings({DMTestEnv::pk_name, vec_column_name}), @@ -2227,7 +2362,7 @@ try } { auto scan_context = std::make_shared(); - auto stream = computeNodeANNQuery({5.0}, EmptyIndexID, 1, scan_context); + auto [stream, rn_dm_ctx] = computeNodeANNQuery({5.0}, EmptyIndexID, 1, scan_context); ASSERT_INPUTSTREAM_COLS_UR( stream, Strings({DMTestEnv::pk_name, vec_column_name}), @@ -2258,7 +2393,7 @@ try { // When cache is evicted (but memory cache exists), the query should be fine. auto scan_context = std::make_shared(); - auto stream = computeNodeANNQuery({5.0}, EmptyIndexID, 1, scan_context); + auto [stream, rn_dm_ctx] = computeNodeANNQuery({5.0}, EmptyIndexID, 1, scan_context); ASSERT_INPUTSTREAM_COLS_UR( stream, Strings({DMTestEnv::pk_name, vec_column_name}), @@ -2275,7 +2410,7 @@ try // Read again, we should be reading from memory cache. auto scan_context = std::make_shared(); - auto stream = computeNodeANNQuery({5.0}, EmptyIndexID, 1, scan_context); + auto [stream, rn_dm_ctx] = computeNodeANNQuery({5.0}, EmptyIndexID, 1, scan_context); ASSERT_INPUTSTREAM_COLS_UR( stream, Strings({DMTestEnv::pk_name, vec_column_name}), @@ -2301,7 +2436,7 @@ try } { auto scan_context = std::make_shared(); - auto stream = computeNodeANNQuery({5.0}, EmptyIndexID, 1, scan_context); + auto [stream, rn_dm_ctx] = computeNodeANNQuery({5.0}, EmptyIndexID, 1, scan_context); ASSERT_INPUTSTREAM_COLS_UR( stream, Strings({DMTestEnv::pk_name, vec_column_name}), @@ -2337,7 +2472,7 @@ try { // When cache is evicted (and memory cache is dropped), the query should be fine. 
auto scan_context = std::make_shared(); - auto stream = computeNodeANNQuery({5.0}, EmptyIndexID, 1, scan_context); + auto [stream, rn_dm_ctx] = computeNodeANNQuery({5.0}, EmptyIndexID, 1, scan_context); ASSERT_INPUTSTREAM_COLS_UR( stream, Strings({DMTestEnv::pk_name, vec_column_name}), @@ -2354,7 +2489,7 @@ try // Read again, we should be reading from memory cache. auto scan_context = std::make_shared(); - auto stream = computeNodeANNQuery({5.0}, EmptyIndexID, 1, scan_context); + auto [stream, rn_dm_ctx] = computeNodeANNQuery({5.0}, EmptyIndexID, 1, scan_context); ASSERT_INPUTSTREAM_COLS_UR( stream, Strings({DMTestEnv::pk_name, vec_column_name}), @@ -2380,7 +2515,7 @@ try } { auto scan_context = std::make_shared(); - auto stream = computeNodeANNQuery({5.0}, EmptyIndexID, 1, scan_context); + auto [stream, rn_dm_ctx] = computeNodeANNQuery({5.0}, EmptyIndexID, 1, scan_context); ASSERT_INPUTSTREAM_COLS_UR( stream, Strings({DMTestEnv::pk_name, vec_column_name}), @@ -2404,7 +2539,7 @@ try { // Query should be fine. auto scan_context = std::make_shared(); - auto stream = computeNodeANNQuery({5.0}, EmptyIndexID, 1, scan_context); + auto [stream, rn_dm_ctx] = computeNodeANNQuery({5.0}, EmptyIndexID, 1, scan_context); ASSERT_INPUTSTREAM_COLS_UR( stream, Strings({DMTestEnv::pk_name, vec_column_name}), @@ -2421,7 +2556,7 @@ try // Read again, we should be reading from memory cache. 
auto scan_context = std::make_shared(); - auto stream = computeNodeANNQuery({5.0}, EmptyIndexID, 1, scan_context); + auto [stream, rn_dm_ctx] = computeNodeANNQuery({5.0}, EmptyIndexID, 1, scan_context); ASSERT_INPUTSTREAM_COLS_UR( stream, Strings({DMTestEnv::pk_name, vec_column_name}), @@ -2447,7 +2582,7 @@ try } { auto scan_context = std::make_shared(); - auto stream = computeNodeANNQuery({5.0}, EmptyIndexID, 1, scan_context); + auto [stream, rn_dm_ctx] = computeNodeANNQuery({5.0}, EmptyIndexID, 1, scan_context); ASSERT_INPUTSTREAM_COLS_UR( stream, Strings({DMTestEnv::pk_name, vec_column_name}), @@ -2477,7 +2612,7 @@ try { // Query should be fine. auto scan_context = std::make_shared(); - auto stream = computeNodeANNQuery({5.0}, EmptyIndexID, 1, scan_context); + auto [stream, rn_dm_ctx] = computeNodeANNQuery({5.0}, EmptyIndexID, 1, scan_context); ASSERT_INPUTSTREAM_COLS_UR( stream, Strings({DMTestEnv::pk_name, vec_column_name}), @@ -2494,7 +2629,7 @@ try // Read again, we should be reading from memory cache. 
auto scan_context = std::make_shared(); - auto stream = computeNodeANNQuery({5.0}, EmptyIndexID, 1, scan_context); + auto [stream, rn_dm_ctx] = computeNodeANNQuery({5.0}, EmptyIndexID, 1, scan_context); ASSERT_INPUTSTREAM_COLS_UR( stream, Strings({DMTestEnv::pk_name, vec_column_name}), @@ -2524,7 +2659,7 @@ try auto th_1 = std::async([&]() { auto scan_context = std::make_shared(); - auto stream = computeNodeANNQuery({5.0}, EmptyIndexID, 1, scan_context); + auto [stream, rn_dm_ctx] = computeNodeANNQuery({5.0}, EmptyIndexID, 1, scan_context); ASSERT_INPUTSTREAM_COLS_UR( stream, Strings({DMTestEnv::pk_name, vec_column_name}), @@ -2546,7 +2681,7 @@ try auto th_2 = std::async([&]() { auto scan_context = std::make_shared(); - auto stream = computeNodeANNQuery({7.0}, EmptyIndexID, 1, scan_context); + auto [stream, rn_dm_ctx] = computeNodeANNQuery({7.0}, EmptyIndexID, 1, scan_context); ASSERT_INPUTSTREAM_COLS_UR( stream, Strings({DMTestEnv::pk_name, vec_column_name}), @@ -2589,7 +2724,7 @@ try } { auto scan_context = std::make_shared(); - auto stream = computeNodeANNQuery({5.0}, EmptyIndexID, 1, scan_context); + auto [stream, rn_dm_ctx] = computeNodeANNQuery({5.0}, EmptyIndexID, 1, scan_context); ASSERT_THROW( { diff --git a/dbms/src/Storages/DeltaMerge/tests/gtest_dm_vector_index_utils.h b/dbms/src/Storages/DeltaMerge/tests/gtest_dm_vector_index_utils.h index d91cbad6f32..69e0f911f46 100644 --- a/dbms/src/Storages/DeltaMerge/tests/gtest_dm_vector_index_utils.h +++ b/dbms/src/Storages/DeltaMerge/tests/gtest_dm_vector_index_utils.h @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -66,6 +67,28 @@ class VectorIndexTestUtils return wb.str(); } + struct AnnQueryInfoTopKOptions + { + std::vector vec; + UInt32 top_k; + Int64 column_id = 100; // vec_column_id + Int64 index_id = 0; + tipb::VectorDistanceMetric distance_metric = tipb::VectorDistanceMetric::L2; + }; + + static ANNQueryInfoPtr annQueryInfoTopK(AnnQueryInfoTopKOptions options) + { + 
auto ann_query_info = std::make_shared(); + ann_query_info->set_query_type(tipb::ANNQueryType::OrderBy); + ann_query_info->set_column_id(options.column_id); + ann_query_info->set_distance_metric(options.distance_metric); + ann_query_info->set_top_k(options.top_k); + ann_query_info->set_ref_vec_f32(encodeVectorFloat32(options.vec)); + if (options.index_id != 0) + ann_query_info->set_index_id(options.index_id); + return ann_query_info; + } + ColumnDefine cdVec() const { // When used in read, no need to assign vector_index. @@ -82,18 +105,25 @@ class VectorIndexTestUtils .kind = tipb::VectorIndexKind::HNSW, .dimension = 1, .distance_metric = tipb::VectorDistanceMetric::L2, - }) + }) const { const LocalIndexInfos index_infos = LocalIndexInfos{ - LocalIndexInfo{ - .kind = TiDB::ColumnarIndexKind::Vector, - .index_id = EmptyIndexID, - .column_id = vec_column_id, - .def_vector_index = std::make_shared(definition), - }, + LocalIndexInfo(EmptyIndexID, vec_column_id, std::make_shared(definition)), }; return std::make_shared(index_infos); } + + static auto wrapVectorStream( + const VectorIndexStreamCtxPtr & ctx, + const SkippableBlockInputStreamPtr & inner, + const BitmapFilterPtr & filter) + { + auto stream = ConcatSkippableBlockInputStream::create( + /* inputs */ {inner}, + /* rows */ {filter->size()}, + /* ScanContext */ nullptr); + return VectorIndexInputStream::create(ctx, filter, stream); + } }; class DeltaMergeStoreVectorBase : public VectorIndexTestUtils diff --git a/dbms/src/Storages/DeltaMerge/tests/gtest_inverted_index.cpp b/dbms/src/Storages/DeltaMerge/tests/gtest_inverted_index.cpp new file mode 100644 index 00000000000..594d2eca5f2 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/tests/gtest_inverted_index.cpp @@ -0,0 +1,302 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include + + +namespace DB::DM::tests +{ + +template +class InvertedIndexTest +{ +public: + static constexpr auto IndexFileName = "test.inverted_index"; + + InvertedIndexTest() = default; + + ~InvertedIndexTest() = default; + + static void writeBlock( + InvertedIndexWriterOnDisk & builder, + const DB::tests::InferredDataVector & values, + const DB::tests::InferredDataVector & del_marks) + { + auto col = DB::tests::createColumn(values).column; + auto del_mark_col = DB::tests::createColumn(del_marks).column; + const auto * del_mark = static_cast *>(del_mark_col.get()); + builder.addBlock(*col, del_mark, []() { return true; }); + } + + class SimpleTestCase + { + static void search(const InvertedIndexReaderPtr & viewer) + { + auto v_search = [&viewer](const UInt64 key, const String & expected) { + auto bitmap_filter = std::make_shared(30, false); + viewer->search(bitmap_filter, key); + ASSERT_EQ(bitmap_filter->toDebugString(), expected); + }; + v_search(1, "100000000010000000001000000000"); + v_search(2, "000000000001000000000111000000"); + v_search(3, "001000000000100000000000111000"); + v_search(4, "000000000000010000000000000111"); + v_search(5, "000010000000001000000000000000"); + v_search(6, "000000000000000100000000000000"); + v_search(7, "000000100000000010000000000000"); + v_search(8, "000000000000000001000000000000"); + v_search(9, "000000001000000000100000000000"); + v_search(10, "000000000000000000010000000000"); + v_search(11, "000000000000000000000000000000"); + + auto v_search_range = 
[&viewer](const UInt64 start, const UInt64 end, const String & expected) { + auto bitmap_filter = std::make_shared(30, false); + viewer->searchRange(bitmap_filter, start, end); + ASSERT_EQ(bitmap_filter->toDebugString(), expected); + }; + v_search_range(1, 2, "100000000011000000001111000000"); + v_search_range(2, 3, "001000000001100000000111111000"); + v_search_range(10, 10, "000000000000000000010000000000"); + v_search_range(1, 11, "101010101011111111111111111111"); + } + + public: + static void run() + { + { + auto builder = InvertedIndexWriterOnDisk(0, IndexFileName); + writeBlock(builder, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {0, 1, 0, 1, 0, 1, 0, 1, 0, 1}); + writeBlock(builder, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + writeBlock(builder, {1, 2, 2, 2, 3, 3, 3, 4, 4, 4}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + builder.finalize(); + } + { + auto viewer = std::make_shared>(IndexFileName); + search(viewer); + } + { + auto viewer = std::make_shared>(IndexFileName); + search(viewer); + } + Poco::File(IndexFileName).remove(); + } + }; + + + class LargeTestCase + { + static constexpr UInt32 block_size = 10000; + static constexpr T block_count = 100; + + static void search(const InvertedIndexReaderPtr & viewer) + { + auto v_search = [&viewer](const UInt64 key, const size_t expected_count) { + auto bitmap_filter = std::make_shared(block_size * block_count, false); + viewer->search(bitmap_filter, key); + ASSERT_EQ(bitmap_filter->count(), expected_count); + }; + std::mt19937 generator; + { + std::uniform_int_distribution distribution(0, block_count); + for (UInt32 i = 0; i < 10; ++i) + v_search(distribution(generator), block_size); + } + { + std::uniform_int_distribution distribution( + std::numeric_limits::min(), + std::numeric_limits::max()); + for (UInt32 i = 0; i < 10; ++i) + { + auto random_v = distribution(generator); + v_search(random_v, (random_v >= block_count || random_v < 0) ? 
0 : block_size); + } + } + auto v_search_range = [&viewer](const UInt64 start, const UInt64 end, const size_t expected_count) { + auto bitmap_filter = std::make_shared(block_size * block_count, false); + viewer->searchRange(bitmap_filter, start, end); + ASSERT_EQ(bitmap_filter->count(), expected_count); + }; + + v_search_range(1, 2, 2 * block_size); + v_search_range(2, 3, 2 * block_size); + v_search_range(10, 10, block_size); + v_search_range(71, 72, 2 * block_size); + v_search_range(1, 99, 99 * block_size); + v_search_range(0, 100, 100 * block_size); + v_search_range(99, 104, 1 * block_size); + v_search_range(100, 104, 0); + } + + static void searchMultiThread(const InvertedIndexReaderPtr & viewer) + { + auto v_search = [&viewer](const UInt64 key, const size_t expected_count) { + auto bitmap_filter = std::make_shared(block_size * block_count, false); + viewer->search(bitmap_filter, key); + ASSERT_EQ(bitmap_filter->count(), expected_count); + }; + std::mt19937 generator; + std::uniform_int_distribution distribution_bc(0, block_count); + std::vector threads; + { + for (UInt32 i = 0; i < 10; ++i) + { + threads.emplace_back([&v_search, &generator, &distribution_bc]() { + auto random_v = distribution_bc(generator); + v_search(random_v, block_size); + }); + } + } + std::uniform_int_distribution distribution_fullrange( + std::numeric_limits::min(), + std::numeric_limits::max()); + { + for (UInt32 i = 0; i < 10; ++i) + { + threads.emplace_back([&v_search, &generator, &distribution_fullrange]() { + auto random_v = distribution_fullrange(generator); + v_search(random_v, (random_v >= block_count || random_v < 0) ? 
0 : block_size); + }); + } + } + + auto v_search_range = [&viewer](const UInt64 start, const UInt64 end, const size_t expected_count) { + auto bitmap_filter = std::make_shared(block_size * block_count, false); + viewer->searchRange(bitmap_filter, start, end); + ASSERT_EQ(bitmap_filter->count(), expected_count); + }; + threads.emplace_back([&v_search_range]() { v_search_range(1, 2, 2 * block_size); }); + threads.emplace_back([&v_search_range]() { v_search_range(2, 3, 2 * block_size); }); + threads.emplace_back([&v_search_range]() { v_search_range(10, 10, block_size); }); + threads.emplace_back([&v_search_range]() { v_search_range(71, 72, 2 * block_size); }); + threads.emplace_back([&v_search_range]() { v_search_range(1, 99, 99 * block_size); }); + threads.emplace_back([&v_search_range]() { v_search_range(0, 100, 100 * block_size); }); + threads.emplace_back([&v_search_range]() { v_search_range(99, 104, 1 * block_size); }); + threads.emplace_back([&v_search_range]() { v_search_range(100, 104, 0); }); + + for (auto & thread : threads) + thread.join(); + } + + public: + static void run() + { + { + auto builder = InvertedIndexWriterOnDisk(0, IndexFileName); + for (UInt32 i = 0; i < block_count; ++i) + { + DB::tests::InferredDataVector values(block_size, i); + DB::tests::InferredDataVector del_marks(block_size, 0); + writeBlock(builder, values, del_marks); + } + builder.finalize(); + } + { + auto viewer = std::make_shared>(IndexFileName); + search(viewer); + } + { + auto viewer = std::make_shared>(IndexFileName); + search(viewer); + } + Poco::File(IndexFileName).remove(); + } + + static void runMultiThread() + { + { + auto builder = InvertedIndexWriterOnDisk(0, IndexFileName); + for (UInt32 i = 0; i < block_count; ++i) + { + DB::tests::InferredDataVector values(block_size, i); + DB::tests::InferredDataVector del_marks(block_size, 0); + writeBlock(builder, values, del_marks); + } + builder.finalize(); + } + { + auto viewer = std::make_shared>(IndexFileName); + 
searchMultiThread(viewer); + } + { + auto viewer = std::make_shared>(IndexFileName); + searchMultiThread(viewer); + } + Poco::File(IndexFileName).remove(); + } + }; +}; + +TEST(InvertedIndex, Simple) +try +{ + InvertedIndexTest::SimpleTestCase::run(); + InvertedIndexTest::SimpleTestCase::run(); + InvertedIndexTest::SimpleTestCase::run(); + InvertedIndexTest::SimpleTestCase::run(); + InvertedIndexTest::SimpleTestCase::run(); + InvertedIndexTest::SimpleTestCase::run(); + InvertedIndexTest::SimpleTestCase::run(); + InvertedIndexTest::SimpleTestCase::run(); +} +CATCH + +// Split the large test case into two parts to avoid long running time. + +TEST(InvertedIndex, Large1) +try +{ + InvertedIndexTest::LargeTestCase::run(); + InvertedIndexTest::LargeTestCase::run(); + InvertedIndexTest::LargeTestCase::run(); + InvertedIndexTest::LargeTestCase::run(); +} +CATCH + +TEST(InvertedIndex, Large2) +try +{ + InvertedIndexTest::LargeTestCase::run(); + InvertedIndexTest::LargeTestCase::run(); + InvertedIndexTest::LargeTestCase::run(); + InvertedIndexTest::LargeTestCase::run(); +} +CATCH + +TEST(InvertedIndex, MultipleThreads1) +try +{ + InvertedIndexTest::LargeTestCase::runMultiThread(); + InvertedIndexTest::LargeTestCase::runMultiThread(); + InvertedIndexTest::LargeTestCase::runMultiThread(); + InvertedIndexTest::LargeTestCase::runMultiThread(); +} +CATCH + +TEST(InvertedIndex, MultipleThreads2) +try +{ + InvertedIndexTest::LargeTestCase::runMultiThread(); + InvertedIndexTest::LargeTestCase::runMultiThread(); + InvertedIndexTest::LargeTestCase::runMultiThread(); + InvertedIndexTest::LargeTestCase::runMultiThread(); +} +CATCH + +} // namespace DB::DM::tests diff --git a/dbms/src/Storages/DeltaMerge/tests/gtest_segment_bitmap.cpp b/dbms/src/Storages/DeltaMerge/tests/gtest_segment_bitmap.cpp index aadbf595707..5831561028d 100644 --- a/dbms/src/Storages/DeltaMerge/tests/gtest_segment_bitmap.cpp +++ b/dbms/src/Storages/DeltaMerge/tests/gtest_segment_bitmap.cpp @@ -17,21 +17,30 @@ 
#include #include #include +#include #include #include #include #include #include - using namespace std::chrono_literals; using namespace DB::tests; namespace DB::DM::tests { -class SegmentBitmapFilterTest : public SegmentTestBasic +class SegmentBitmapFilterTest + : public SegmentTestBasic + , public testing::WithParamInterface { +public: + void SetUp() override + { + is_common_handle = GetParam(); + SegmentTestBasic::SetUp(SegmentTestOptions{.is_common_handle = is_common_handle}); + } + protected: DB::LoggerPtr log = DB::Logger::get("SegmentBitmapFilterTest"); static constexpr auto SEG_ID = DELTA_MERGE_FIRST_SEGMENT_ID; @@ -39,11 +48,21 @@ class SegmentBitmapFilterTest : public SegmentTestBasic ColumnPtr hold_handle; RowKeyRanges read_ranges; - void setRowKeyRange(Int64 begin, Int64 end) + void setRowKeyRange(Int64 begin, Int64 end, bool including_right_boundary) { auto itr = segments.find(SEG_ID); RUNTIME_CHECK(itr != segments.end(), SEG_ID); - itr->second->rowkey_range = buildRowKeyRange(begin, end); + itr->second->rowkey_range = buildRowKeyRange(begin, end, is_common_handle, including_right_boundary); + } + + void writeSegmentGeneric( + std::string_view seg_data, + std::optional> rowkey_range = std::nullopt) + { + if (is_common_handle) + writeSegment(seg_data, rowkey_range); + else + writeSegment(seg_data, rowkey_range); } /* @@ -61,12 +80,16 @@ class SegmentBitmapFilterTest : public SegmentTestBasic Returns {row_id, handle}. 
*/ - std::pair *, const PaddedPODArray *> writeSegment( + template + std::pair *, const std::optional>> writeSegment( std::string_view seg_data, - std::optional> rowkey_range = std::nullopt) + std::optional> rowkey_range = std::nullopt) { if (rowkey_range) - setRowKeyRange(rowkey_range->first, rowkey_range->second); + { + const auto & [left, right, including_right_boundary] = *rowkey_range; + setRowKeyRange(left, right, including_right_boundary); + } auto seg_data_units = parseSegData(seg_data); for (const auto & unit : seg_data_units) { @@ -77,43 +100,42 @@ class SegmentBitmapFilterTest : public SegmentTestBasic if (hold_row_id == nullptr) { RUNTIME_CHECK(hold_handle == nullptr); - return {nullptr, nullptr}; + return {nullptr, std::nullopt}; } else { - RUNTIME_CHECK(hold_handle != nullptr); - return {toColumnVectorDataPtr(hold_row_id), toColumnVectorDataPtr(hold_handle)}; + return {toColumnVectorDataPtr(hold_row_id), ColumnView(*hold_handle)}; } } void writeSegment(const SegDataUnit & unit) { const auto & type = unit.type; - auto [begin, end] = unit.range; - + auto [begin, end, including_right_boundary] = unit.range; + const auto write_count = end - begin + including_right_boundary; if (type == "d_mem") { - SegmentTestBasic::writeSegment(SEG_ID, end - begin, begin); + SegmentTestBasic::writeToCache(SEG_ID, write_count, begin, unit.shuffle, unit.ts); } else if (type == "d_mem_del") { - SegmentTestBasic::writeSegmentWithDeletedPack(SEG_ID, end - begin, begin); + SegmentTestBasic::writeSegmentWithDeletedPack(SEG_ID, write_count, begin); } else if (type == "d_tiny") { - SegmentTestBasic::writeSegment(SEG_ID, end - begin, begin); + SegmentTestBasic::writeSegment(SEG_ID, write_count, begin, unit.shuffle, unit.ts); SegmentTestBasic::flushSegmentCache(SEG_ID); } else if (type == "d_tiny_del") { - SegmentTestBasic::writeSegmentWithDeletedPack(SEG_ID, end - begin, begin); + SegmentTestBasic::writeSegmentWithDeletedPack(SEG_ID, write_count, begin); 
SegmentTestBasic::flushSegmentCache(SEG_ID); } else if (type == "d_big") { SegmentTestBasic::ingestDTFileIntoDelta( SEG_ID, - end - begin, + write_count, begin, false, unit.pack_size, @@ -121,18 +143,29 @@ class SegmentBitmapFilterTest : public SegmentTestBasic } else if (type == "d_dr") { - SegmentTestBasic::writeSegmentWithDeleteRange(SEG_ID, begin, end); + SegmentTestBasic::writeSegmentWithDeleteRange( + SEG_ID, + begin, + end, + is_common_handle, + including_right_boundary); } else if (type == "s") { - SegmentTestBasic::writeSegment(SEG_ID, end - begin, begin); - if (unit.pack_size) - { - db_context->getSettingsRef().dt_segment_stable_pack_rows = *(unit.pack_size); - reloadDMContext(); - ASSERT_EQ(dm_context->stable_pack_rows, *(unit.pack_size)); - } - SegmentTestBasic::mergeSegmentDelta(SEG_ID); + SegmentTestBasic::writeSegment(SEG_ID, write_count, begin, unit.shuffle, unit.ts); + SegmentTestBasic::mergeSegmentDelta(SEG_ID, /*check_rows*/ true, unit.pack_size); + } + else if (type == "compact_delta") + { + SegmentTestBasic::compactSegmentDelta(SEG_ID); + } + else if (type == "flush_cache") + { + SegmentTestBasic::flushSegmentCache(SEG_ID); + } + else if (type == "merge_delta") + { + SegmentTestBasic::mergeSegmentDelta(SEG_ID, /*check_rows*/ true, unit.pack_size); } else { @@ -142,42 +175,40 @@ class SegmentBitmapFilterTest : public SegmentTestBasic struct TestCase { - TestCase( - std::string_view seg_data_, - size_t expected_size_, - std::string_view expected_row_id_, - std::string_view expected_handle_, - std::optional> rowkey_range_ = std::nullopt) - : seg_data(seg_data_) - , expected_size(expected_size_) - , expected_row_id(expected_row_id_) - , expected_handle(expected_handle_) - , rowkey_range(rowkey_range_) - {} std::string seg_data; size_t expected_size; std::string expected_row_id; std::string expected_handle; - std::optional> rowkey_range; + std::optional> rowkey_range; }; - void runTestCase(TestCase test_case) + void runTestCaseGeneric(TestCase 
test_case, int caller_line) + { + if (is_common_handle) + runTestCase(test_case, caller_line); + else + runTestCase(test_case, caller_line); + } + + template + void runTestCase(TestCase test_case, int caller_line) { - auto [row_id, handle] = writeSegment(test_case.seg_data, test_case.rowkey_range); + auto info = fmt::format("caller_line={}", caller_line); + auto [row_id, handle] = writeSegment(test_case.seg_data, test_case.rowkey_range); if (test_case.expected_size == 0) { - ASSERT_EQ(nullptr, row_id); - ASSERT_EQ(nullptr, handle); + ASSERT_EQ(nullptr, row_id) << info; + ASSERT_EQ(std::nullopt, handle) << info; } else { - ASSERT_EQ(test_case.expected_size, row_id->size()); + ASSERT_EQ(test_case.expected_size, row_id->size()) << info; auto expected_row_id = genSequence(test_case.expected_row_id); - ASSERT_TRUE(sequenceEqual(expected_row_id, *row_id)); + ASSERT_TRUE(sequenceEqual(expected_row_id, *row_id)) << info; - ASSERT_EQ(test_case.expected_size, handle->size()); - auto expected_handle = genSequence(test_case.expected_handle); - ASSERT_TRUE(sequenceEqual(expected_handle, *handle)); + ASSERT_EQ(test_case.expected_size, handle->size()) << info; + auto expected_handle = genHandleSequence(test_case.expected_handle); + ASSERT_TRUE(sequenceEqual(expected_handle, *handle)) << info; } } @@ -192,176 +223,264 @@ class SegmentBitmapFilterTest : public SegmentTestBasic } return results; } + + void checkHandle(PageIdU64 seg_id, std::string_view seq_ranges, int caller_line) + { + auto info = fmt::format("caller_line={}", caller_line); + auto handle = getSegmentHandle(seg_id, {}); + if (is_common_handle) + { + auto expected_handle = genHandleSequence(seq_ranges); + ASSERT_TRUE(sequenceEqual(expected_handle, ColumnView{*handle})) << info; + } + else + { + auto expected_handle = genHandleSequence(seq_ranges); + ASSERT_TRUE(sequenceEqual(expected_handle, ColumnView{*handle})) << info; + } + } + + bool is_common_handle = false; }; -TEST_F(SegmentBitmapFilterTest, InMemory1) 
+INSTANTIATE_TEST_CASE_P(MVCC, SegmentBitmapFilterTest, /* is_common_handle */ ::testing::Bool()); + +TEST_P(SegmentBitmapFilterTest, InMemory1) try { - runTestCase(TestCase("d_mem:[0, 1000)", 1000, "[0, 1000)", "[0, 1000)")); + runTestCaseGeneric( + TestCase{ + .seg_data = "d_mem:[0, 1000)", + .expected_size = 1000, + .expected_row_id = "[0, 1000)", + .expected_handle = "[0, 1000)"}, + __LINE__); } CATCH -TEST_F(SegmentBitmapFilterTest, InMemory2) +TEST_P(SegmentBitmapFilterTest, InMemory2) try { - runTestCase(TestCase{"d_mem:[0, 1000)|d_mem:[0, 1000)", 1000, "[1000, 2000)", "[0, 1000)"}); + runTestCaseGeneric( + TestCase{ + .seg_data = "d_mem:[0, 1000)|d_mem:[0, 1000)", + .expected_size = 1000, + .expected_row_id = "[1000, 2000)", + .expected_handle = "[0, 1000)"}, + __LINE__); } CATCH -TEST_F(SegmentBitmapFilterTest, InMemory3) +TEST_P(SegmentBitmapFilterTest, InMemory3) try { - runTestCase(TestCase{"d_mem:[0, 1000)|d_mem:[100, 200)", 1000, "[0, 100)|[1000, 1100)|[200, 1000)", "[0, 1000)"}); + runTestCaseGeneric( + TestCase{ + .seg_data = "d_mem:[0, 1000)|d_mem:[100, 200)", + .expected_size = 1000, + .expected_row_id = "[0, 100)|[1000, 1100)|[200, 1000)", + .expected_handle = "[0, 1000)"}, + __LINE__); } CATCH -TEST_F(SegmentBitmapFilterTest, InMemory4) +TEST_P(SegmentBitmapFilterTest, InMemory4) try { - runTestCase(TestCase{"d_mem:[0, 1000)|d_mem:[-100, 100)", 1100, "[1000, 1200)|[100, 1000)", "[-100, 1000)"}); + runTestCaseGeneric( + TestCase{ + .seg_data = "d_mem:[0, 1000)|d_mem:[-100, 100)", + .expected_size = 1100, + .expected_row_id = "[1000, 1200)|[100, 1000)", + .expected_handle = "[-100, 1000)"}, + __LINE__); } CATCH -TEST_F(SegmentBitmapFilterTest, InMemory5) +TEST_P(SegmentBitmapFilterTest, InMemory5) try { - runTestCase(TestCase{"d_mem:[0, 1000)|d_mem_del:[0, 1000)", 0, "", ""}); + runTestCaseGeneric( + TestCase{ + .seg_data = "d_mem:[0, 1000)|d_mem_del:[0, 1000)", + .expected_size = 0, + .expected_row_id = "", + .expected_handle = ""}, + __LINE__); 
} CATCH -TEST_F(SegmentBitmapFilterTest, InMemory6) +TEST_P(SegmentBitmapFilterTest, InMemory6) try { - runTestCase(TestCase{"d_mem:[0, 1000)|d_mem_del:[100, 200)", 900, "[0, 100)|[200, 1000)", "[0, 100)|[200, 1000)"}); + runTestCaseGeneric( + TestCase{ + .seg_data = "d_mem:[0, 1000)|d_mem_del:[100, 200)", + .expected_size = 900, + .expected_row_id = "[0, 100)|[200, 1000)", + .expected_handle = "[0, 100)|[200, 1000)"}, + __LINE__); } CATCH -TEST_F(SegmentBitmapFilterTest, InMemory7) +TEST_P(SegmentBitmapFilterTest, InMemory7) try { - runTestCase(TestCase{"d_mem:[0, 1000)|d_mem_del:[-100, 100)", 900, "[100, 1000)", "[100, 1000)"}); + runTestCaseGeneric( + TestCase{ + .seg_data = "d_mem:[0, 1000)|d_mem_del:[-100, 100)", + .expected_size = 900, + .expected_row_id = "[100, 1000)", + .expected_handle = "[100, 1000)"}, + __LINE__); } CATCH -TEST_F(SegmentBitmapFilterTest, Tiny1) +TEST_P(SegmentBitmapFilterTest, Tiny1) try { - runTestCase(TestCase{"d_tiny:[100, 500)|d_mem:[200, 1000)", 900, "[0, 100)|[400, 1200)", "[100, 1000)"}); + runTestCaseGeneric( + TestCase{ + .seg_data = "d_tiny:[100, 500)|d_mem:[200, 1000)", + .expected_size = 900, + .expected_row_id = "[0, 100)|[400, 1200)", + .expected_handle = "[100, 1000)"}, + __LINE__); } CATCH -TEST_F(SegmentBitmapFilterTest, TinyDel1) +TEST_P(SegmentBitmapFilterTest, TinyDel1) try { - runTestCase(TestCase{ - "d_tiny:[100, 500)|d_tiny_del:[200, 300)|d_mem:[0, 100)", - 400, - "[500, 600)|[0, 100)|[200, 400)", - "[0, 200)|[300, 500)"}); + runTestCaseGeneric( + TestCase{ + .seg_data = "d_tiny:[100, 500)|d_tiny_del:[200, 300)|d_mem:[0, 100)", + .expected_size = 400, + .expected_row_id = "[500, 600)|[0, 100)|[200, 400)", + .expected_handle = "[0, 200)|[300, 500)"}, + __LINE__); } CATCH -TEST_F(SegmentBitmapFilterTest, DeleteRange) +TEST_P(SegmentBitmapFilterTest, DeleteRange) try { - runTestCase(TestCase{ - "d_tiny:[100, 500)|d_dr:[250, 300)|d_mem:[240, 290)", - 390, - "[0, 140)|[400, 450)|[200, 400)", - "[100, 290)|[300, 
500)"}); + runTestCaseGeneric( + TestCase{ + .seg_data = "d_tiny:[100, 500)|d_dr:[250, 300)|d_mem:[240, 290)", + .expected_size = 390, + .expected_row_id = "[0, 140)|[400, 450)|[200, 400)", + .expected_handle = "[100, 290)|[300, 500)"}, + __LINE__); } CATCH -TEST_F(SegmentBitmapFilterTest, Big) +TEST_P(SegmentBitmapFilterTest, Big) try { - runTestCase(TestCase{ - "d_tiny:[100, 500)|d_big:[250, 1000)|d_mem:[240, 290)", - 900, - "[0, 140)|[1150, 1200)|[440, 1150)", - "[100, 1000)"}); + runTestCaseGeneric( + TestCase{ + .seg_data = "d_tiny:[100, 500)|d_big:[250, 1000)|d_mem:[240, 290)", + .expected_size = 900, + .expected_row_id = "[0, 140)|[1150, 1200)|[440, 1150)", + .expected_handle = "[100, 1000)"}, + __LINE__); } CATCH -TEST_F(SegmentBitmapFilterTest, Stable1) +TEST_P(SegmentBitmapFilterTest, Stable1) try { - runTestCase(TestCase{"s:[0, 1024)", 1024, "[0, 1024)", "[0, 1024)"}); + runTestCaseGeneric( + TestCase{ + .seg_data = "s:[0, 1024)", + .expected_size = 1024, + .expected_row_id = "[0, 1024)", + .expected_handle = "[0, 1024)"}, + __LINE__); } CATCH -TEST_F(SegmentBitmapFilterTest, Stable2) +TEST_P(SegmentBitmapFilterTest, Stable2) try { - runTestCase(TestCase{"s:[0, 1024)|d_dr:[0, 1023)", 1, "[1023, 1024)", "[1023, 1024)"}); + runTestCaseGeneric( + TestCase{ + .seg_data = "s:[0, 1024)|d_dr:[0, 1023)", + .expected_size = 1, + .expected_row_id = "[1023, 1024)", + .expected_handle = "[1023, 1024)"}, + __LINE__); } CATCH -TEST_F(SegmentBitmapFilterTest, Stable3) +TEST_P(SegmentBitmapFilterTest, Stable3) try { - runTestCase(TestCase{ - "s:[0, 1024)|d_dr:[128, 256)|d_tiny_del:[300, 310)", - 886, - "[0, 128)|[256, 300)|[310, 1024)", - "[0, 128)|[256, 300)|[310, 1024)"}); + runTestCaseGeneric( + TestCase{ + .seg_data = "s:[0, 1024)|d_dr:[128, 256)|d_tiny_del:[300, 310)", + .expected_size = 886, + .expected_row_id = "[0, 128)|[256, 300)|[310, 1024)", + .expected_handle = "[0, 128)|[256, 300)|[310, 1024)"}, + __LINE__); } CATCH -TEST_F(SegmentBitmapFilterTest, Mix) 
+TEST_P(SegmentBitmapFilterTest, Mix) try { - runTestCase(TestCase{ - "s:[0, 1024)|d_dr:[128, 256)|d_tiny_del:[300, 310)|d_tiny:[200, 255)|d_mem:[298, 305)", - 946, - "[0, 128)|[1034, 1089)|[256, 298)|[1089, 1096)|[310, 1024)", - "[0, 128)|[200, 255)|[256, 305)|[310, 1024)"}); + runTestCaseGeneric( + TestCase{ + .seg_data = "s:[0, 1024)|d_dr:[128, 256)|d_tiny_del:[300, 310)|d_tiny:[200, 255)|d_mem:[298, 305)", + .expected_size = 946, + .expected_row_id = "[0, 128)|[1034, 1089)|[256, 298)|[1089, 1096)|[310, 1024)", + .expected_handle = "[0, 128)|[200, 255)|[256, 305)|[310, 1024)"}, + __LINE__); } CATCH -TEST_F(SegmentBitmapFilterTest, Ranges) +TEST_P(SegmentBitmapFilterTest, Ranges) try { - read_ranges.emplace_back(buildRowKeyRange(222, 244)); - read_ranges.emplace_back(buildRowKeyRange(300, 303)); - read_ranges.emplace_back(buildRowKeyRange(555, 666)); - runTestCase(TestCase{ - "s:[0, 1024)|d_dr:[128, 256)|d_tiny_del:[300, 310)|d_tiny:[200, 255)|d_mem:[298, 305)", - 136, - "[1056, 1078)|[1091, 1094)|[555, 666)", - "[222, 244)|[300, 303)|[555, 666)"}); + read_ranges.emplace_back(buildRowKeyRange(222, 244, is_common_handle)); + read_ranges.emplace_back(buildRowKeyRange(300, 303, is_common_handle)); + read_ranges.emplace_back(buildRowKeyRange(555, 666, is_common_handle)); + runTestCaseGeneric( + TestCase{ + .seg_data = "s:[0, 1024)|d_dr:[128, 256)|d_tiny_del:[300, 310)|d_tiny:[200, 255)|d_mem:[298, 305)", + .expected_size = 136, + .expected_row_id = "[1056, 1078)|[1091, 1094)|[555, 666)", + .expected_handle = "[222, 244)|[300, 303)|[555, 666)"}, + __LINE__); } CATCH -TEST_F(SegmentBitmapFilterTest, LogicalSplit) +TEST_P(SegmentBitmapFilterTest, LogicalSplit) try { - runTestCase(TestCase{ - "s:[0, 1024)|d_dr:[128, 256)|d_tiny_del:[300, 310)|d_tiny:[200, 255)|d_mem:[298, 305)", - 946, - "[0, 128)|[1034, 1089)|[256, 298)|[1089, 1096)|[310, 1024)", - "[0, 128)|[200, 255)|[256, 305)|[310, 1024)"}); + runTestCaseGeneric( + TestCase{ + .seg_data = "s:[0, 1024)|d_dr:[128, 
256)|d_tiny_del:[300, 310)|d_tiny:[200, 255)|d_mem:[298, 305)", + .expected_size = 946, + .expected_row_id = "[0, 128)|[1034, 1089)|[256, 298)|[1089, 1096)|[310, 1024)", + .expected_handle = "[0, 128)|[200, 255)|[256, 305)|[310, 1024)"}, + __LINE__); auto new_seg_id = splitSegmentAt(SEG_ID, 512, Segment::SplitMode::Logical); ASSERT_TRUE(new_seg_id.has_value()); ASSERT_TRUE(areSegmentsSharingStable({SEG_ID, *new_seg_id})); - auto left_handle = getSegmentHandle(SEG_ID, {}); - const auto & left_h = toColumnVectorData(left_handle); - auto expected_left_handle = genSequence("[0, 128)|[200, 255)|[256, 305)|[310, 512)"); - ASSERT_TRUE(sequenceEqual(expected_left_handle, left_h)); + checkHandle(SEG_ID, "[0, 128)|[200, 255)|[256, 305)|[310, 512)", __LINE__); auto left_row_id = getSegmentRowId(SEG_ID, {}); const auto & left_r = toColumnVectorData(left_row_id); auto expected_left_row_id = genSequence("[0, 128)|[1034, 1089)|[256, 298)|[1089, 1096)|[310, 512)"); ASSERT_TRUE(sequenceEqual(expected_left_row_id, left_r)); - auto right_handle = getSegmentHandle(*new_seg_id, {}); - const auto & right_h = toColumnVectorData(right_handle); - auto expected_right_handle = genSequence("[512, 1024)"); - ASSERT_TRUE(sequenceEqual(expected_right_handle, right_h)); + checkHandle(*new_seg_id, "[512, 1024)", __LINE__); auto right_row_id = getSegmentRowId(*new_seg_id, {}); const auto & right_r = toColumnVectorData(right_row_id); @@ -370,9 +489,9 @@ try } CATCH -TEST_F(SegmentBitmapFilterTest, CleanStable) +TEST_P(SegmentBitmapFilterTest, CleanStable) { - writeSegment("d_mem:[0, 20000)|d_mem:[30000, 35000)"); + writeSegmentGeneric("d_mem:[0, 20000)|d_mem:[30000, 35000)"); mergeSegmentDelta(SEG_ID, true); auto [seg, snap] = getSegmentForRead(SEG_ID); ASSERT_EQ(seg->getDelta()->getRows(), 0); @@ -391,9 +510,9 @@ TEST_F(SegmentBitmapFilterTest, CleanStable) ASSERT_EQ(bitmap_filter->toDebugString(), expect_result); } -TEST_F(SegmentBitmapFilterTest, NotCleanStable) +TEST_P(SegmentBitmapFilterTest, 
NotCleanStable) { - writeSegment("d_mem:[0, 10000)|d_mem:[5000, 15000)"); + writeSegmentGeneric("d_mem:[0, 10000)|d_mem:[5000, 15000)"); mergeSegmentDelta(SEG_ID, true); auto [seg, snap] = getSegmentForRead(SEG_ID); ASSERT_EQ(seg->getDelta()->getRows(), 0); @@ -439,16 +558,16 @@ TEST_F(SegmentBitmapFilterTest, NotCleanStable) } } -TEST_F(SegmentBitmapFilterTest, StableRange) +TEST_P(SegmentBitmapFilterTest, StableRange) { - writeSegment("d_mem:[0, 50000)"); + writeSegmentGeneric("d_mem:[0, 50000)"); mergeSegmentDelta(SEG_ID, true); auto [seg, snap] = getSegmentForRead(SEG_ID); ASSERT_EQ(seg->getDelta()->getRows(), 0); ASSERT_EQ(seg->getDelta()->getDeletes(), 0); ASSERT_EQ(seg->getStable()->getRows(), 50000); - auto ranges = std::vector{buildRowKeyRange(10000, 50000)}; // [10000, 50000) + auto ranges = std::vector{buildRowKeyRange(10000, 50000, is_common_handle)}; // [10000, 50000) auto bitmap_filter = seg->buildBitmapFilterStableOnly( *dm_context, snap, @@ -464,10 +583,10 @@ TEST_F(SegmentBitmapFilterTest, StableRange) ASSERT_EQ(bitmap_filter->toDebugString(), expect_result); } -TEST_F(SegmentBitmapFilterTest, StableLogicalSplit) +TEST_P(SegmentBitmapFilterTest, StableLogicalSplit) try { - writeSegment("d_mem:[0, 50000)"); + writeSegmentGeneric("d_mem:[0, 50000)"); mergeSegmentDelta(SEG_ID, true); auto [seg, snap] = getSegmentForRead(SEG_ID); ASSERT_EQ(seg->getDelta()->getRows(), 0); @@ -479,20 +598,14 @@ try ASSERT_TRUE(new_seg_id.has_value()); ASSERT_TRUE(areSegmentsSharingStable({SEG_ID, *new_seg_id})); - auto left_handle = getSegmentHandle(SEG_ID, {}); - const auto & left_h = toColumnVectorData(left_handle); - auto expected_left_handle = genSequence("[0, 25000)"); - ASSERT_TRUE(sequenceEqual(expected_left_handle, left_h)); + checkHandle(SEG_ID, "[0, 25000)", __LINE__); auto left_row_id = getSegmentRowId(SEG_ID, {}); const auto & left_r = toColumnVectorData(left_row_id); auto expected_left_row_id = genSequence("[0, 25000)"); 
ASSERT_TRUE(sequenceEqual(expected_left_row_id, left_r)); - auto right_handle = getSegmentHandle(*new_seg_id, {}); - const auto & right_h = toColumnVectorData(right_handle); - auto expected_right_handle = genSequence("[25000, 50000)"); - ASSERT_TRUE(sequenceEqual(expected_right_handle, right_h)); + checkHandle(*new_seg_id, "[25000, 50000)", __LINE__); auto right_row_id = getSegmentRowId(*new_seg_id, {}); const auto & right_r = toColumnVectorData(right_row_id); @@ -501,17 +614,19 @@ try } CATCH -TEST_F(SegmentBitmapFilterTest, BigPart) +TEST_P(SegmentBitmapFilterTest, BigPart) try { // For ColumnFileBig, only packs that intersection with the rowkey range will be considered in BitmapFilter. // Packs in rowkey_range: [270, 280)|[280, 290)|[290, 300) - runTestCase(TestCase{ - /*seg_data*/ "d_big:[250, 1000):10", - /*expected_size*/ 20, - /*expected_row_id*/ "[5, 25)", - /*expected_handle*/ "[275, 295)", - /*rowkey_range*/ std::pair{275, 295}}); + runTestCaseGeneric( + TestCase{ + .seg_data = "d_big:[250, 1000):pack_size_10", + .expected_size = 20, + .expected_row_id = "[5, 25)", + .expected_handle = "[275, 295)", + .rowkey_range = std::tuple{275, 295, false}}, + __LINE__); auto [seg, snap] = getSegmentForRead(SEG_ID); auto bitmap_filter = seg->buildBitmapFilter( @@ -527,14 +642,16 @@ try } CATCH -TEST_F(SegmentBitmapFilterTest, StablePart) +TEST_P(SegmentBitmapFilterTest, StablePart) try { - runTestCase(TestCase{ - /*seg_data*/ "s:[250, 1000):10", - /*expected_size*/ 750, - /*expected_row_id*/ "[0, 750)", - /*expected_handle*/ "[250, 1000)"}); + runTestCaseGeneric( + TestCase{ + .seg_data = "s:[250, 1000):pack_size_10", + .expected_size = 750, + .expected_row_id = "[0, 750)", + .expected_handle = "[250, 1000)"}, + __LINE__); { auto [seg, snap] = getSegmentForRead(SEG_ID); @@ -542,7 +659,7 @@ try } // For Stable, all packs of DMFile will be considered in BitmapFilter. 
- setRowKeyRange(275, 295); // Shrinking range + setRowKeyRange(275, 295, false); // Shrinking range auto [seg, snap] = getSegmentForRead(SEG_ID); auto bitmap_filter = seg->buildBitmapFilter( *dm_context, @@ -565,7 +682,7 @@ try } CATCH -TEST_F(SegmentBitmapFilterTest, testSkipPackStableOnly) +TEST_P(SegmentBitmapFilterTest, testSkipPackStableOnly) { std::string expect_result; expect_result.append(std::string(200, '0')); @@ -581,14 +698,14 @@ TEST_F(SegmentBitmapFilterTest, testSkipPackStableOnly) reloadDMContext(); version = 0; - writeSegment("d_mem:[0, 1000)|d_mem:[500, 1500)|d_mem:[1500, 2000)"); + writeSegmentGeneric("d_mem:[0, 1000)|d_mem:[500, 1500)|d_mem:[1500, 2000)"); mergeSegmentDelta(SEG_ID, true); auto [seg, snap] = getSegmentForRead(SEG_ID); ASSERT_EQ(seg->getDelta()->getRows(), 0); ASSERT_EQ(seg->getDelta()->getDeletes(), 0); ASSERT_EQ(seg->getStable()->getRows(), 2500); - auto ranges = std::vector{buildRowKeyRange(200, 2000)}; + auto ranges = std::vector{buildRowKeyRange(200, 2000, is_common_handle)}; auto pack_filter_results = loadPackFilterResults(snap, ranges); if (pack_rows == 1) @@ -630,7 +747,7 @@ TEST_F(SegmentBitmapFilterTest, testSkipPackStableOnly) } } -TEST_F(SegmentBitmapFilterTest, testSkipPackNormal) +TEST_P(SegmentBitmapFilterTest, testSkipPackNormal) { std::string expect_result; expect_result.append(std::string(50, '0')); @@ -656,15 +773,15 @@ TEST_F(SegmentBitmapFilterTest, testSkipPackNormal) reloadDMContext(); version = 0; - writeSegment("d_mem:[0, 1000)|d_mem:[500, 1500)|d_mem:[1500, 2000)"); + writeSegmentGeneric("d_mem:[0, 1000)|d_mem:[500, 1500)|d_mem:[1500, 2000)"); mergeSegmentDelta(SEG_ID, true); - writeSegment("d_tiny:[99, 100)|d_dr:[355, 370)|d_dr:[409, 481)|d_mem:[200, 201)|d_mem:[301, 315)"); + writeSegmentGeneric("d_tiny:[99, 100)|d_dr:[355, 370)|d_dr:[409, 481)|d_mem:[200, 201)|d_mem:[301, 315)"); auto [seg, snap] = getSegmentForRead(SEG_ID); ASSERT_EQ(seg->getDelta()->getRows(), 16); 
ASSERT_EQ(seg->getDelta()->getDeletes(), 2); ASSERT_EQ(seg->getStable()->getRows(), 2500); - auto ranges = std::vector{buildRowKeyRange(50, 2000)}; + auto ranges = std::vector{buildRowKeyRange(50, 2000, is_common_handle)}; auto pack_filter_results = loadPackFilterResults(snap, ranges); UInt64 start_ts = 6; if (pack_rows == 10) @@ -754,5 +871,4 @@ TEST_F(SegmentBitmapFilterTest, testSkipPackNormal) deleteRangeSegment(SEG_ID); } } - } // namespace DB::DM::tests diff --git a/dbms/src/Storages/DeltaMerge/tests/gtest_segment_test_basic.cpp b/dbms/src/Storages/DeltaMerge/tests/gtest_segment_test_basic.cpp index 641584987f4..29e663b1812 100644 --- a/dbms/src/Storages/DeltaMerge/tests/gtest_segment_test_basic.cpp +++ b/dbms/src/Storages/DeltaMerge/tests/gtest_segment_test_basic.cpp @@ -59,6 +59,26 @@ extern DMFilePtr writeIntoNewDMFile( namespace DB::DM::tests { +namespace +{ + +// a + b +bool isSumOverflow(Int64 a, Int64 b) +{ + return (b > 0 && a > std::numeric_limits::max() - b) || (b < 0 && a < std::numeric_limits::min() - b); +} + +// a - b +auto isDiffOverflow(Int64 a, Int64 b) +{ + assert(a > b); + if (b < 0) + return a > std::numeric_limits::max() + b; + else + return false; +} +} // namespace + void SegmentTestBasic::reloadWithOptions(SegmentTestOptions config) { { @@ -302,9 +322,24 @@ void SegmentTestBasic::mergeSegment(const PageIdU64s & segments_id, bool check_r operation_statistics["mergeTwo"]++; } -void SegmentTestBasic::mergeSegmentDelta(PageIdU64 segment_id, bool check_rows) +void SegmentTestBasic::mergeSegmentDelta(PageIdU64 segment_id, bool check_rows, std::optional pack_size) { - LOG_INFO(logger_op, "mergeSegmentDelta, segment_id={}", segment_id); + LOG_INFO(logger_op, "mergeSegmentDelta, segment_id={}, pack_size={}", segment_id, pack_size); + + const size_t initial_pack_rows = db_context->getSettingsRef().dt_segment_stable_pack_rows; + if (pack_size) + { + db_context->getSettingsRef().dt_segment_stable_pack_rows = *pack_size; + reloadDMContext(); + 
ASSERT_EQ(dm_context->stable_pack_rows, *pack_size); + } + SCOPE_EXIT({ + if (initial_pack_rows != db_context->getSettingsRef().dt_segment_stable_pack_rows) + { + db_context->getSettingsRef().dt_segment_stable_pack_rows = initial_pack_rows; + reloadDMContext(); + } + }); RUNTIME_CHECK(segments.find(segment_id) != segments.end()); auto segment = segments[segment_id]; @@ -330,6 +365,18 @@ void SegmentTestBasic::flushSegmentCache(PageIdU64 segment_id) operation_statistics["flush"]++; } +void SegmentTestBasic::compactSegmentDelta(PageIdU64 segment_id) +{ + LOG_INFO(logger_op, "compactSegmentDelta, segment_id={}", segment_id); + + RUNTIME_CHECK(segments.find(segment_id) != segments.end()); + auto segment = segments[segment_id]; + size_t segment_row_num = getSegmentRowNum(segment_id); + segment->compactDelta(*dm_context); + EXPECT_EQ(getSegmentRowNum(segment_id), segment_row_num); + operation_statistics["compact"]++; +} + std::pair SegmentTestBasic::getSegmentKeyRange(PageIdU64 segment_id) const { RUNTIME_CHECK(segments.find(segment_id) != segments.end()); @@ -367,29 +414,43 @@ std::pair SegmentTestBasic::getSegmentKeyRange(PageIdU64 segment_i return {start_key, end_key}; } -Block SegmentTestBasic::prepareWriteBlockImpl(Int64 start_key, Int64 end_key, bool is_deleted) +Block SegmentTestBasic::prepareWriteBlockImpl( + Int64 start_key, + Int64 end_key, + bool is_deleted, + bool including_right_boundary, + std::optional ts) { - RUNTIME_CHECK(start_key <= end_key); - if (end_key == start_key) + RUNTIME_CHECK(start_key <= end_key, start_key, end_key); + if (end_key == start_key && !including_right_boundary) return Block{}; - version++; + + UInt64 v = ts.value_or(version + 1); + version = std::max(v, version + 1); // Increase version return DMTestEnv::prepareSimpleWriteBlock( start_key, // end_key, false, - version, + v, DMTestEnv::pk_name, MutSup::extra_handle_id, options.is_common_handle ? 
MutSup::getExtraHandleColumnStringType() : MutSup::getExtraHandleColumnIntType(), options.is_common_handle, 1, true, - is_deleted); + is_deleted, + /*with_null_uint64*/ false, + including_right_boundary); } -Block SegmentTestBasic::prepareWriteBlock(Int64 start_key, Int64 end_key, bool is_deleted) +Block SegmentTestBasic::prepareWriteBlock( + Int64 start_key, + Int64 end_key, + bool is_deleted, + bool including_right_boundary, + std::optional ts) { - return prepareWriteBlockImpl(start_key, end_key, is_deleted); + return prepareWriteBlockImpl(start_key, end_key, is_deleted, including_right_boundary, ts); } Block sortvstackBlocks(std::vector && blocks) @@ -406,28 +467,48 @@ Block sortvstackBlocks(std::vector && blocks) Block SegmentTestBasic::prepareWriteBlockInSegmentRange( PageIdU64 segment_id, - UInt64 total_write_rows, + Int64 total_write_rows, std::optional write_start_key, - bool is_deleted) + bool is_deleted, + std::optional ts) { - RUNTIME_CHECK(total_write_rows < std::numeric_limits::max()); + RUNTIME_CHECK(0 < total_write_rows, total_write_rows); - RUNTIME_CHECK(segments.find(segment_id) != segments.end()); - auto [segment_start_key, segment_end_key] = getSegmentKeyRange(segment_id); - auto segment_max_rows = static_cast(segment_end_key - segment_start_key); - - if (segment_max_rows == 0) - return {}; + // For example, write_start_key is int64_max and total_write_rows is 1, + // We will generate block with one row with int64_max. 
+ auto is_including_right_boundary = [](Int64 write_start_key, Int64 total_write_rows) { + return std::numeric_limits::max() - total_write_rows + 1 == write_start_key; + }; + auto seg = segments.find(segment_id); + RUNTIME_CHECK(seg != segments.end()); + auto segment_range = seg->second->getRowKeyRange(); + auto [segment_start_key, segment_end_key] = getSegmentKeyRange(segment_id); + bool including_right_boundary = false; if (write_start_key.has_value()) { - // When write start key is specified, the caller must know exactly the segment range. - RUNTIME_CHECK(*write_start_key >= segment_start_key); - RUNTIME_CHECK(static_cast(segment_end_key - *write_start_key) > 0); - } + RUNTIME_CHECK_MSG( + segment_start_key <= *write_start_key + && (*write_start_key < segment_end_key || segment_range.isEndInfinite()), + "write_start_key={} segment_range={}", + *write_start_key, + segment_range.toDebugString()); - if (!write_start_key.has_value()) + including_right_boundary = is_including_right_boundary(*write_start_key, total_write_rows); + RUNTIME_CHECK( + including_right_boundary || !isSumOverflow(*write_start_key, total_write_rows), + *write_start_key, + total_write_rows); + } + else { + auto segment_max_rows = isDiffOverflow(segment_end_key, segment_start_key) + ? std::numeric_limits::max() // UInt64 is more accurate, but Int64 is enough and for simplicity. + : segment_end_key - segment_start_key; + + if (segment_max_rows == 0) + return {}; + // When write start key is unspecified, we will: // A. If the segment is large enough, we randomly pick a write start key in the range. // B. If the segment is small, we write from the beginning. 
@@ -460,10 +541,10 @@ Block SegmentTestBasic::prepareWriteBlockInSegmentRange( RUNTIME_CHECK(write_rows_this_round > 0); Int64 write_end_key_this_round = *write_start_key + static_cast(write_rows_this_round); RUNTIME_CHECK(write_end_key_this_round <= segment_end_key); - - Block block = prepareWriteBlock(*write_start_key, write_end_key_this_round, is_deleted); + Block block + = prepareWriteBlock(*write_start_key, write_end_key_this_round, is_deleted, including_right_boundary, ts); blocks.emplace_back(block); - remaining_rows -= write_rows_this_round; + remaining_rows -= write_rows_this_round + including_right_boundary; LOG_DEBUG( logger, @@ -478,7 +559,52 @@ Block SegmentTestBasic::prepareWriteBlockInSegmentRange( return sortvstackBlocks(std::move(blocks)); } -void SegmentTestBasic::writeSegment(PageIdU64 segment_id, UInt64 write_rows, std::optional start_at) +void SegmentTestBasic::writeToCache( + PageIdU64 segment_id, + UInt64 write_rows, + Int64 start_at, + bool shuffle, + std::optional ts) +{ + LOG_INFO(logger_op, "writeToCache, segment_id={} write_rows={}", segment_id, write_rows); + + if (write_rows == 0) + return; + + RUNTIME_CHECK(segments.find(segment_id) != segments.end()); + auto segment = segments[segment_id]; + size_t segment_row_num = getSegmentRowNumWithoutMVCC(segment_id); + auto [start_key, end_key] = getSegmentKeyRange(segment_id); + LOG_DEBUG( + logger, + "write to segment, segment={} segment_rows={} start_key={} end_key={}", + segment->info(), + segment_row_num, + start_key, + end_key); + + auto block = prepareWriteBlockInSegmentRange(segment_id, write_rows, start_at, /* is_deleted */ false, ts); + + if (shuffle) + { + IColumn::Permutation perm(block.rows()); + std::iota(perm.begin(), perm.end(), 0); + std::shuffle(perm.begin(), perm.end(), std::mt19937(std::random_device()())); + for (auto & column : block) + column.column = column.column->permute(perm, 0); + } + segment->writeToCache(*dm_context, block, 0, block.rows()); + + 
EXPECT_EQ(getSegmentRowNumWithoutMVCC(segment_id), segment_row_num + write_rows); + operation_statistics["write"]++; +} + +void SegmentTestBasic::writeSegment( + PageIdU64 segment_id, + UInt64 write_rows, + std::optional start_at, + bool shuffle, + std::optional ts) { LOG_INFO(logger_op, "writeSegment, segment_id={} write_rows={}", segment_id, write_rows); @@ -497,7 +623,16 @@ void SegmentTestBasic::writeSegment(PageIdU64 segment_id, UInt64 write_rows, std start_key, end_key); - auto block = prepareWriteBlockInSegmentRange(segment_id, write_rows, start_at, /* is_deleted */ false); + auto block = prepareWriteBlockInSegmentRange(segment_id, write_rows, start_at, /* is_deleted */ false, ts); + + if (shuffle) + { + IColumn::Permutation perm(block.rows()); + std::iota(perm.begin(), perm.end(), 0); + std::shuffle(perm.begin(), perm.end(), std::mt19937(std::random_device()())); + for (auto & column : block) + column.column = column.column->permute(perm, 0); + } segment->write(*dm_context, block, false); EXPECT_EQ(getSegmentRowNumWithoutMVCC(segment_id), segment_row_num + write_rows); @@ -521,7 +656,10 @@ BlockInputStreamPtr SegmentTestBasic::getIngestDTFileInputStream( rows_per_block = std::min(rows_per_block, write_rows - written); std::optional start; if (start_at) + { + RUNTIME_CHECK(!isSumOverflow(*start_at, written), *start_at, written); start.emplace(*start_at + written); + } if (check_range) { @@ -530,9 +668,13 @@ BlockInputStreamPtr SegmentTestBasic::getIngestDTFileInputStream( } else { - auto start_key = start ? *start : 0; - auto end_key = start_key + rows_per_block; - auto block = prepareWriteBlock(start_key, end_key); + Int64 start_key = start.value_or(0); + bool overflow = std::numeric_limits::max() - static_cast(rows_per_block) < start_key; + Int64 end_key = overflow ? 
std::numeric_limits::max() : start_key + rows_per_block; + // if overflow, write [start_key, end_key] + // if not overflow, write [start_key, end_key) + rows_per_block += overflow; + auto block = prepareWriteBlock(start_key, end_key, /*is_deleted*/ false, overflow); streams.push_back(std::make_shared(std::move(block))); } } @@ -1053,10 +1195,33 @@ Block mergeSegmentRowIds(std::vector && blocks) return accumulated_block; } -RowKeyRange SegmentTestBasic::buildRowKeyRange(Int64 begin, Int64 end) +RowKeyRange SegmentTestBasic::buildRowKeyRange( + Int64 begin, + Int64 end, + bool is_common_handle, + bool including_right_boundary) { - HandleRange range(begin, end); - return RowKeyRange::fromHandleRange(range); + // `including_right_boundary` is for creating range like [begin, std::numeric_limits::max()) or [begin, std::numeric_limits::max()] + if (including_right_boundary) + RUNTIME_CHECK(end == std::numeric_limits::max()); + + if (is_common_handle) + { + auto create_rowkey_value = [](Int64 v) { + WriteBufferFromOwnString ss; + DB::EncodeUInt(static_cast(TiDB::CodecFlagInt), ss); + DB::EncodeInt64(v, ss); + return std::make_shared(ss.releaseStr()); + }; + auto left = RowKeyValue{is_common_handle, create_rowkey_value(begin)}; + auto right = including_right_boundary ? RowKeyValue::COMMON_HANDLE_MAX_KEY + : RowKeyValue{is_common_handle, create_rowkey_value(end)}; + return RowKeyRange{left, right, is_common_handle, 1}; + } + + auto left = RowKeyValue::fromHandle(begin); + auto right = including_right_boundary ? 
RowKeyValue::INT_HANDLE_MAX_KEY : RowKeyValue::fromHandle(end); + return RowKeyRange{left, right, is_common_handle, 1}; } std::pair SegmentTestBasic::getSegmentForRead(PageIdU64 segment_id) @@ -1126,9 +1291,14 @@ ColumnPtr SegmentTestBasic::getSegmentHandle(PageIdU64 segment_id, const RowKeyR } } -void SegmentTestBasic::writeSegmentWithDeleteRange(PageIdU64 segment_id, Int64 begin, Int64 end) +void SegmentTestBasic::writeSegmentWithDeleteRange( + PageIdU64 segment_id, + Int64 begin, + Int64 end, + bool is_common_handle, + bool including_right_boundary) { - auto range = buildRowKeyRange(begin, end); + auto range = buildRowKeyRange(begin, end, is_common_handle, including_right_boundary); RUNTIME_CHECK(segments.find(segment_id) != segments.end()); auto segment = segments[segment_id]; RUNTIME_CHECK(segment->write(*dm_context, range)); diff --git a/dbms/src/Storages/DeltaMerge/tests/gtest_segment_test_basic.h b/dbms/src/Storages/DeltaMerge/tests/gtest_segment_test_basic.h index 0ac5d82cb76..0eb91754d38 100644 --- a/dbms/src/Storages/DeltaMerge/tests/gtest_segment_test_basic.h +++ b/dbms/src/Storages/DeltaMerge/tests/gtest_segment_test_basic.h @@ -65,14 +65,24 @@ class SegmentTestBasic : public DB::base::TiFlashStorageTestBasic Segment::SplitMode split_mode = Segment::SplitMode::Auto, bool check_rows = true); void mergeSegment(const std::vector & segments, bool check_rows = true); - void mergeSegmentDelta(PageIdU64 segment_id, bool check_rows = true); + void mergeSegmentDelta( + PageIdU64 segment_id, + bool check_rows = true, + std::optional pack_size = std::nullopt); void flushSegmentCache(PageIdU64 segment_id); + void compactSegmentDelta(PageIdU64 segment_id); /** * When begin_key is specified, new rows will be written from specified key. Otherwise, new rows may be * written randomly in the segment range. 
*/ - void writeSegment(PageIdU64 segment_id, UInt64 write_rows = 100, std::optional start_at = std::nullopt); + void writeToCache(PageIdU64 segment_id, UInt64 write_rows, Int64 start_at, bool shuffle, std::optional ts); + void writeSegment( + PageIdU64 segment_id, + UInt64 write_rows = 100, + std::optional start_at = std::nullopt, + bool shuffle = false, + std::optional ts = std::nullopt); void ingestDTFileIntoDelta( PageIdU64 segment_id, UInt64 write_rows = 100, @@ -113,12 +123,18 @@ class SegmentTestBasic : public DB::base::TiFlashStorageTestBasic */ bool ensureSegmentDeltaLocalIndex(PageIdU64 segment_id, const LocalIndexInfosPtr & local_index_infos); - Block prepareWriteBlock(Int64 start_key, Int64 end_key, bool is_deleted = false); + Block prepareWriteBlock( + Int64 start_key, + Int64 end_key, + bool is_deleted = false, + bool including_right_boundary = false, + std::optional ts = std::nullopt); Block prepareWriteBlockInSegmentRange( PageIdU64 segment_id, - UInt64 total_write_rows, + Int64 total_write_rows, std::optional write_start_key = std::nullopt, - bool is_deleted = false); + bool is_deleted = false, + std::optional ts = std::nullopt); size_t getSegmentRowNumWithoutMVCC(PageIdU64 segment_id); size_t getSegmentRowNum(PageIdU64 segment_id); @@ -138,9 +154,18 @@ class SegmentTestBasic : public DB::base::TiFlashStorageTestBasic std::vector readSegment(PageIdU64 segment_id, bool need_row_id, const RowKeyRanges & ranges); ColumnPtr getSegmentRowId(PageIdU64 segment_id, const RowKeyRanges & ranges); ColumnPtr getSegmentHandle(PageIdU64 segment_id, const RowKeyRanges & ranges); - void writeSegmentWithDeleteRange(PageIdU64 segment_id, Int64 begin, Int64 end); + void writeSegmentWithDeleteRange( + PageIdU64 segment_id, + Int64 begin, + Int64 end, + bool is_common_handle, + bool including_right_boundary); RowKeyValue buildRowKeyValue(Int64 key); - static RowKeyRange buildRowKeyRange(Int64 begin, Int64 end); + static RowKeyRange buildRowKeyRange( + Int64 begin, + 
Int64 end, + bool is_common_handle, + bool including_right_boundary = false); size_t getPageNumAfterGC(StorageType type, NamespaceID ns_id) const; @@ -161,7 +186,12 @@ class SegmentTestBasic : public DB::base::TiFlashStorageTestBasic const ColumnDefinesPtr & tableColumns() const { return table_columns; } - virtual Block prepareWriteBlockImpl(Int64 start_key, Int64 end_key, bool is_deleted); + virtual Block prepareWriteBlockImpl( + Int64 start_key, + Int64 end_key, + bool is_deleted, + bool including_right_boundary, + std::optional ts); virtual void prepareColumns(const ColumnDefinesPtr &) {} diff --git a/dbms/src/Storages/DeltaMerge/tests/gtest_segment_util.cpp b/dbms/src/Storages/DeltaMerge/tests/gtest_segment_util.cpp index 7790c81dc4f..9aba27700ac 100644 --- a/dbms/src/Storages/DeltaMerge/tests/gtest_segment_util.cpp +++ b/dbms/src/Storages/DeltaMerge/tests/gtest_segment_util.cpp @@ -21,58 +21,103 @@ namespace DB::DM::tests namespace { -// "[a, b)" => std::pair{a, b} +Strings splitAndTrim(std::string_view s, std::string_view delimiter, std::optional expected_size = std::nullopt) +{ + Strings results; + boost::split(results, s, boost::is_any_of(delimiter)); + if (expected_size) + RUNTIME_CHECK(results.size() == *expected_size, s, delimiter, expected_size); + else + RUNTIME_CHECK(!results.empty(), s, delimiter); + for (auto & r : results) + boost::trim(r); + return results; +} + template -std::pair parseRange(String & str_range) +SegDataRange parseRange(std::string_view s) { - boost::algorithm::trim(str_range); - RUNTIME_CHECK(str_range.front() == '[' && str_range.back() == ')', str_range); - std::vector values; - str_range = str_range.substr(1, str_range.size() - 2); - boost::split(values, str_range, boost::is_any_of(",")); - RUNTIME_CHECK(values.size() == 2, str_range); - return {static_cast(std::stol(values[0])), static_cast(std::stol(values[1]))}; + auto str_range = boost::trim_copy(s); + RUNTIME_CHECK(str_range.front() == '[' && (str_range.back() == ')' || 
str_range.back() == ']'), str_range); + auto values = splitAndTrim(str_range.substr(1, str_range.size() - 2), ",", 2); + return SegDataRange{ + .left = static_cast(std::stol(values[0])), + .right = static_cast(std::stol(values[1])), + .including_right_boundary = str_range.back() == ']'}; } -// "[a, b)|[c, d)" => [std::pair{a, b}, std::pair{c, d}] +// "[a, b)|[c, d]" => [{a, b, false}, {c, d, true}] template -std::vector> parseRanges(std::string_view str_ranges) +std::vector> parseRanges(std::string_view s) { - std::vector ranges; - boost::split(ranges, str_ranges, boost::is_any_of("|")); - RUNTIME_CHECK(!ranges.empty(), str_ranges); - std::vector> vector_ranges; + auto str_range = boost::trim_copy(s); + RUNTIME_CHECK(str_range.front() == '[' && (str_range.back() == ')' || str_range.back() == ']'), str_range); + auto ranges = splitAndTrim(str_range, "|"); + std::vector> vector_ranges; vector_ranges.reserve(ranges.size()); for (auto & r : ranges) - { vector_ranges.emplace_back(parseRange(r)); - } return vector_ranges; } -// "type:[a, b)" => SegDataUnit -SegDataUnit parseSegDataUnit(String & s) +const std::unordered_set segment_commands = {"flush_cache", "compact_delta", "merge_delta"}; +const std::unordered_set delta_small_data_types = {"d_mem", "d_mem_del", "d_tiny", "d_tiny_del"}; +const std::unordered_set segment_data_types + = {"d_mem", "d_mem_del", "d_tiny", "d_tiny_del", "d_dr", "s"}; + +void parseSegUnitAttr(std::string_view attr, SegDataUnit & unit) { - boost::algorithm::trim(s); - std::vector values; - boost::split(values, s, boost::is_any_of(":")); - if (values.size() == 2) + // Pack size for DMFile + static const std::string_view attr_pack_size_prefix{"pack_size_"}; + // Shuffle data ColumnFileTiny and ColumnFileMemory + static const std::string_view attr_shuffle{"shuffle"}; + // Timestamp for generated data + static const std::string_view attr_timestamp_prefix{"ts_"}; + + if (attr.starts_with(attr_pack_size_prefix)) + { + RUNTIME_CHECK(unit.type == 
"d_big" || unit.type == "s" || unit.type == "merge_delta", attr, unit.type); + unit.pack_size = std::stoul(String(attr.substr(attr_pack_size_prefix.size()))); + return; + } + + if (attr == attr_shuffle) { - return SegDataUnit{ - .type = boost::algorithm::trim_copy(values[0]), - .range = parseRange(values[1]), - }; + RUNTIME_CHECK(delta_small_data_types.contains(unit.type), attr, unit.type, delta_small_data_types); + unit.shuffle = true; + return; } - else if (values.size() == 3) + + if (attr.starts_with(attr_timestamp_prefix)) { - RUNTIME_CHECK(values[0] == "d_big" || values[0] == "s", s); - return SegDataUnit{ - .type = boost::algorithm::trim_copy(values[0]), - .range = parseRange(values[1]), - .pack_size = std::stoul(values[2]), - }; + RUNTIME_CHECK( + delta_small_data_types.contains(unit.type) || unit.type == "s", + attr, + unit.type, + delta_small_data_types); + unit.ts = std::stoul(String(attr.substr(attr_timestamp_prefix.size()))); + return; } - RUNTIME_CHECK_MSG(false, "parseSegDataUnit failed: {}", s); + + RUNTIME_CHECK_MSG(false, "{} is unsupported", attr); +} + +// data_type:[left, right):attr1:attr2 +// cmd_type:attr1:attr2 +SegDataUnit parseSegDataUnit(std::string_view s) +{ + auto s_trim = boost::trim_copy(s); + auto values = splitAndTrim(s_trim, ":"); + size_t i = 0; + SegDataUnit unit{.type = values[i++]}; + if (!segment_commands.contains(unit.type)) + { + RUNTIME_CHECK(values.size() >= i, s, values); + unit.range = parseRange(values[i++]); + } + for (; i < values.size(); i++) + parseSegUnitAttr(values[i], unit); + return unit; } void check(const std::vector & seg_data_units) @@ -83,6 +128,9 @@ void check(const std::vector & seg_data_units) for (size_t i = 0; i < seg_data_units.size(); i++) { const auto & type = seg_data_units[i].type; + if (segment_commands.contains(type)) + continue; + if (type == "s") { stable_units.emplace_back(i); @@ -91,31 +139,29 @@ void check(const std::vector & seg_data_units) { mem_units.emplace_back(i); } - auto [begin, 
end] = seg_data_units[i].range; - RUNTIME_CHECK(begin < end, begin, end); + auto [begin, end, including_right_boundary] = seg_data_units[i].range; + RUNTIME_CHECK(end - begin + including_right_boundary > 0, begin, end, including_right_boundary); } + // If stable exists, it should be the first one. RUNTIME_CHECK(stable_units.empty() || (stable_units.size() == 1 && stable_units[0] == 0)); - std::vector expected_mem_units(mem_units.size()); - std::iota(expected_mem_units.begin(), expected_mem_units.end(), seg_data_units.size() - mem_units.size()); - RUNTIME_CHECK(mem_units == expected_mem_units, expected_mem_units, mem_units); } template -std::vector genSequence(T begin, T end) +std::vector genSequence(T begin, T end, bool including_right_boundary) { - auto size = end - begin; + auto size = end - begin + including_right_boundary; std::vector v(size); std::iota(v.begin(), v.end(), begin); return v; } template -std::vector genSequence(const std::vector> & ranges) +std::vector genSequence(const std::vector> & ranges) { std::vector res; - for (auto [begin, end] : ranges) + for (auto [begin, end, including_right_boundary] : ranges) { - auto v = genSequence(begin, end); + auto v = genSequence(begin, end, including_right_boundary); res.insert(res.end(), v.begin(), v.end()); } return res; @@ -124,12 +170,10 @@ std::vector genSequence(const std::vector> & ranges) std::vector parseSegData(std::string_view seg_data) { - std::vector str_seg_data_units; - boost::split(str_seg_data_units, seg_data, boost::is_any_of("|")); - RUNTIME_CHECK(!str_seg_data_units.empty(), seg_data); + auto str_seg_data_units = splitAndTrim(seg_data, "|"); std::vector seg_data_units; seg_data_units.reserve(str_seg_data_units.size()); - for (auto & s : str_seg_data_units) + for (const auto & s : str_seg_data_units) { seg_data_units.emplace_back(parseSegDataUnit(s)); } diff --git a/dbms/src/Storages/DeltaMerge/tests/gtest_segment_util.h b/dbms/src/Storages/DeltaMerge/tests/gtest_segment_util.h index 
ef1ffffb18a..3f90a9cae92 100644 --- a/dbms/src/Storages/DeltaMerge/tests/gtest_segment_util.h +++ b/dbms/src/Storages/DeltaMerge/tests/gtest_segment_util.h @@ -14,17 +14,34 @@ #pragma once +#include #include #include #include namespace DB::DM::tests { + +// If including_right_boundary is false, it means [left, right). +// If including_right_boundary is true, it means [left, right]. +// `including_right_boundary` is required if we want to generate data with std::numeric_limits::max(). +// Theoretically, we could enforce the use of closed intervals, thereby eliminating the need for the parameter 'including_right_boundary'. +// However, a multitude of existing tests are predicated on the assumption that the interval is left-closed and right-open. +template +struct SegDataRange +{ + T left; + T right; + bool including_right_boundary; +}; + struct SegDataUnit { String type; - std::pair range; // Data range + SegDataRange range; std::optional pack_size; // For DMFile + bool shuffle = false; // For ColumnFileTiny and ColumnFileMemory + std::optional ts; }; std::vector parseSegData(std::string_view seg_data); @@ -32,6 +49,22 @@ std::vector parseSegData(std::string_view seg_data); template std::vector genSequence(std::string_view str_ranges); +template +std::vector genHandleSequence(std::string_view str_ranges) +{ + auto v = genSequence(str_ranges); + if constexpr (std::is_same_v) + return v; + else + { + static_assert(std::is_same_v); + std::vector res(v.size()); + for (size_t i = 0; i < v.size(); i++) + res[i] = genMockCommonHandle(v[i], 1); + return res; + } +} + template ::testing::AssertionResult sequenceEqual(const E & expected, const A & actual) { diff --git a/dbms/src/Storages/DeltaMerge/tests/gtest_skippable_block_input_stream.cpp b/dbms/src/Storages/DeltaMerge/tests/gtest_skippable_block_input_stream.cpp index e4ebbf2f6a0..ce47ee3807e 100644 --- a/dbms/src/Storages/DeltaMerge/tests/gtest_skippable_block_input_stream.cpp +++ 
b/dbms/src/Storages/DeltaMerge/tests/gtest_skippable_block_input_stream.cpp @@ -269,37 +269,42 @@ class SkippableBlockInputStreamTest : public SegmentTestBasic void writeSegment(const SegDataUnit & unit) { const auto & type = unit.type; - auto [begin, end] = unit.range; - + const auto [begin, end, including_right_boundary] = unit.range; + const auto write_count = end - begin + including_right_boundary; if (type == "d_mem") { - SegmentTestBasic::writeSegment(SEG_ID, end - begin, begin); + SegmentTestBasic::writeSegment(SEG_ID, write_count, begin); } else if (type == "d_mem_del") { - SegmentTestBasic::writeSegmentWithDeletedPack(SEG_ID, end - begin, begin); + SegmentTestBasic::writeSegmentWithDeletedPack(SEG_ID, write_count, begin); } else if (type == "d_tiny") { - SegmentTestBasic::writeSegment(SEG_ID, end - begin, begin); + SegmentTestBasic::writeSegment(SEG_ID, write_count, begin); SegmentTestBasic::flushSegmentCache(SEG_ID); } else if (type == "d_tiny_del") { - SegmentTestBasic::writeSegmentWithDeletedPack(SEG_ID, end - begin, begin); + SegmentTestBasic::writeSegmentWithDeletedPack(SEG_ID, write_count, begin); SegmentTestBasic::flushSegmentCache(SEG_ID); } else if (type == "d_big") { - SegmentTestBasic::ingestDTFileIntoDelta(SEG_ID, end - begin, begin, false); + SegmentTestBasic::ingestDTFileIntoDelta(SEG_ID, write_count, begin, false); } else if (type == "d_dr") { - SegmentTestBasic::writeSegmentWithDeleteRange(SEG_ID, begin, end); + SegmentTestBasic::writeSegmentWithDeleteRange( + SEG_ID, + begin, + end, + /*is_common_handle*/ false, + including_right_boundary); } else if (type == "s") { - SegmentTestBasic::writeSegment(SEG_ID, end - begin, begin); + SegmentTestBasic::writeSegment(SEG_ID, write_count, begin); SegmentTestBasic::mergeSegmentDelta(SEG_ID); } else diff --git a/dbms/src/Storages/DeltaMerge/tests/gtest_sst_files_stream.cpp b/dbms/src/Storages/DeltaMerge/tests/gtest_sst_files_stream.cpp index af787a1c5f8..6e25e67af85 100644 --- 
a/dbms/src/Storages/DeltaMerge/tests/gtest_sst_files_stream.cpp +++ b/dbms/src/Storages/DeltaMerge/tests/gtest_sst_files_stream.cpp @@ -14,13 +14,13 @@ #include #include +#include #include #include #include #include #include #include -#include #include #include #include @@ -41,7 +41,7 @@ class SSTFilesToDTFilesOutputStreamTest : public DB::base::TiFlashStorageTestBas public: void SetUp() override { - mock_region = makeRegion(1, RecordKVFormat::genKey(1, 0), RecordKVFormat::genKey(1, 1000)); + mock_region = RegionBench::makeRegionForTable(1, table_id, 0, 1000); TiFlashStorageTestBasic::SetUp(); setupStorage(); @@ -50,13 +50,13 @@ class SSTFilesToDTFilesOutputStreamTest : public DB::base::TiFlashStorageTestBas void TearDown() override { storage->drop(); - db_context->getTMTContext().getStorages().remove(NullspaceID, /* table id */ 100); + db_context->getTMTContext().getStorages().remove(NullspaceID, table_id); } void setupStorage() { auto columns = DM::tests::DMTestEnv::getDefaultTableColumns(pk_type); - auto table_info = DM::tests::DMTestEnv::getMinimalTableInfo(/* table id */ 100, pk_type); + auto table_info = DM::tests::DMTestEnv::getMinimalTableInfo(table_id, pk_type); auto astptr = DM::tests::DMTestEnv::getPrimaryKeyExpr("test_table", pk_type); storage = StorageDeltaMerge::create( @@ -103,6 +103,7 @@ class SSTFilesToDTFilesOutputStreamTest : public DB::base::TiFlashStorageTestBas } protected: + TableID table_id = 100; StorageDeltaMergePtr storage; RegionPtr mock_region; DMTestEnv::PkType pk_type = DMTestEnv::PkType::HiddenTiDBRowID; diff --git a/dbms/src/Storages/DeltaMerge/workload/Options.cpp b/dbms/src/Storages/DeltaMerge/workload/Options.cpp index 23a4c0e3a06..d4cd9d75d09 100644 --- a/dbms/src/Storages/DeltaMerge/workload/Options.cpp +++ b/dbms/src/Storages/DeltaMerge/workload/Options.cpp @@ -15,10 +15,9 @@ #include #include #include +#include #include -#include - namespace DB::DM::tests { std::string WorkloadOptions::toString(std::string seperator) const 
diff --git a/dbms/src/Storages/IManageableStorage.h b/dbms/src/Storages/IManageableStorage.h index 9df7f17242f..db399a157af 100644 --- a/dbms/src/Storages/IManageableStorage.h +++ b/dbms/src/Storages/IManageableStorage.h @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -34,7 +35,6 @@ namespace DB { struct SchemaNameMapper; class ASTStorage; -class Region; namespace DM { diff --git a/dbms/src/Storages/IStorage.h b/dbms/src/Storages/IStorage.h index bc75af28843..d1af315a918 100644 --- a/dbms/src/Storages/IStorage.h +++ b/dbms/src/Storages/IStorage.h @@ -23,8 +23,6 @@ #include #include -#include -#include namespace DB @@ -72,25 +70,16 @@ class IStorage */ virtual std::string getTableName() const = 0; - /** Returns true if the storage receives data from a remote server or servers. */ - virtual bool isRemote() const { return false; } - /** Returns true if the storage supports queries with the SAMPLE section. */ virtual bool supportsSampling() const { return false; } /** Returns true if the storage supports queries with the FINAL section. */ virtual bool supportsFinal() const { return false; } - /** Returns true if the storage supports queries with the PREWHERE section. */ - virtual bool supportsPrewhere() const { return false; } - - /** Returns true if the storage replicates SELECT, INSERT and ALTER commands among replicas. */ - virtual bool supportsReplication() const { return false; } - /** Returns true if the storage supports UPSERT, DELETE or UPDATE. */ virtual bool supportsModification() const { return false; } - /// Lock table for share. This lock must be acuqired if you want to be sure, + /// Lock table for share. This lock must be acquired if you want to be sure, /// that table will be not dropped while you holding this lock. It's used in /// variety of cases starting from SELECT queries to background merges in /// MergeTree. 
@@ -98,7 +87,7 @@ class IStorage const String & query_id, const std::chrono::milliseconds & acquire_timeout = std::chrono::milliseconds(0)); - /// Lock table for alter. This lock must be acuqired in ALTER queries to be + /// Lock table for alter. This lock must be acquired in ALTER queries to be /// sure, that we execute only one simultaneous alter. Doesn't affect share lock. TableLockHolder lockForAlter( const String & query_id, @@ -303,12 +292,6 @@ class IStorage bool is_dropped{false}; - /// Does table support index for IN sections - virtual bool supportsIndexForIn() const { return false; } - - /// Provides a hint that the storage engine may evaluate the IN-condition by using an index. - virtual bool mayBenefitFromIndexForIn(const ASTPtr & /* left_in_operand */) const { return false; } - /// Checks validity of the data virtual bool checkData() const { @@ -320,9 +303,6 @@ class IStorage /// Otherwise - throws an exception with detailed information or returns false virtual bool checkTableCanBeDropped() const { return true; } - /** Notify engine about updated dependencies for this storage. */ - virtual void updateDependencies() {} - /// Returns data path if storage supports it, empty string otherwise. 
virtual String getDataPath() const { return {}; } diff --git a/dbms/src/Storages/KVStore/BackgroundService.h b/dbms/src/Storages/KVStore/BackgroundService.h index 06309532c09..2aa17a972e2 100644 --- a/dbms/src/Storages/KVStore/BackgroundService.h +++ b/dbms/src/Storages/KVStore/BackgroundService.h @@ -20,16 +20,10 @@ #include #include -#include -#include namespace DB { class TMTContext; -class Region; -using RegionPtr = std::shared_ptr; -using Regions = std::vector; -using RegionMap = std::unordered_map; class BackgroundProcessingPool; class BackgroundService : boost::noncopyable diff --git a/dbms/src/Storages/KVStore/Decode/PartitionStreams.cpp b/dbms/src/Storages/KVStore/Decode/PartitionStreams.cpp index 9f55148dd75..4e47246b305 100644 --- a/dbms/src/Storages/KVStore/Decode/PartitionStreams.cpp +++ b/dbms/src/Storages/KVStore/Decode/PartitionStreams.cpp @@ -263,20 +263,15 @@ DM::WriteResult writeRegionDataToStorage( } } -std::variant resolveLocksAndReadRegionData( + +std::variant RegionTable::checkRegionAndGetLocks( const TiDB::TableID table_id, const RegionPtr & region, const Timestamp start_ts, const std::unordered_set * bypass_lock_ts, RegionVersion region_version, - RegionVersion conf_version, - bool resolve_locks, - bool need_data_value) + RegionVersion conf_version) { - LockInfoPtr lock_info; - - auto scanner = region->createCommittedScanner(true, need_data_value); - /// Some sanity checks for region meta. { /** @@ -303,33 +298,12 @@ std::variantcreateCommittedScanner(true, /*need_data_value*/ false); + /// Get transaction locks that should be resolved in this region. + auto lock_info = scanner.getLockInfo(RegionLockReadQuery{.read_tso = start_ts, .bypass_lock_ts = bypass_lock_ts}); if (lock_info) return lock_info; - - /// If there is no lock, leave scope of region scanner and raise LockException. - /// Read raw KVs from region cache. - RegionDataReadInfoList data_list_read; - // Shortcut for empty region. 
- if (!scanner.hasNext()) - return data_list_read; - - // If worked with raftstore v2, the final size may not equal to here. - data_list_read.reserve(scanner.writeMapSize()); - - // Tiny optimization for queries that need only handle, tso, delmark. - do - { - data_list_read.emplace_back(scanner.next()); - } while (scanner.hasNext()); - return data_list_read; + return RegionException::RegionReadStatus::OK; } std::optional ReadRegionCommitCache(const RegionPtr & region, bool lock_region) @@ -456,42 +430,6 @@ DM::WriteResult RegionTable::writeCommittedByRegion( return write_result; } -RegionTable::ResolveLocksAndWriteRegionRes RegionTable::resolveLocksAndWriteRegion( - TMTContext & tmt, - const TiDB::TableID table_id, - const RegionPtr & region, - const Timestamp start_ts, - const std::unordered_set * bypass_lock_ts, - RegionVersion region_version, - RegionVersion conf_version, - const LoggerPtr & log) -{ - auto region_data_lock = resolveLocksAndReadRegionData( - table_id, - region, - start_ts, - bypass_lock_ts, - region_version, - conf_version, - /* resolve_locks */ true, - /* need_data_value */ true); - - return std::visit( - variant_op::overloaded{ - [&](RegionDataReadInfoList & data_list_read) -> ResolveLocksAndWriteRegionRes { - if (data_list_read.empty()) - return RegionException::RegionReadStatus::OK; - auto & context = tmt.getContext(); - // There is no raft input here, so we can just ignore the fg flush request. - writeRegionDataToStorage(context, region, data_list_read, log); - RemoveRegionCommitCache(region, data_list_read); - return RegionException::RegionReadStatus::OK; - }, - [](auto & r) -> ResolveLocksAndWriteRegionRes { return std::move(r); }, - }, - region_data_lock); -} - // Note that there could be a chance that the table have been totally removed from TiKV // and TiFlash can not get the IStorage instance. // - Check whether the StorageDeltaMerge is nullptr or not before you accessing to it. 
diff --git a/dbms/src/Storages/KVStore/Decode/PartitionStreams.h b/dbms/src/Storages/KVStore/Decode/PartitionStreams.h index 0fa9c30be53..0f9a56ada41 100644 --- a/dbms/src/Storages/KVStore/Decode/PartitionStreams.h +++ b/dbms/src/Storages/KVStore/Decode/PartitionStreams.h @@ -19,12 +19,11 @@ #include #include #include +#include #include namespace DB { -class Region; -using RegionPtr = std::shared_ptr; class StorageDeltaMerge; class TMTContext; diff --git a/dbms/src/Storages/KVStore/Decode/RegionTable.cpp b/dbms/src/Storages/KVStore/Decode/RegionTable.cpp index 03f4ef535e0..75e10b8bb34 100644 --- a/dbms/src/Storages/KVStore/Decode/RegionTable.cpp +++ b/dbms/src/Storages/KVStore/Decode/RegionTable.cpp @@ -517,16 +517,4 @@ void RegionTable::removeTableFromIndex(KeyspaceID keyspace_id, TableID table_id) } } -RegionPtrWithSnapshotFiles::RegionPtrWithSnapshotFiles( - const Base & base_, - std::vector && external_files_) - : base(base_) - , external_files(std::move(external_files_)) -{} - -RegionPtrWithCheckpointInfo::RegionPtrWithCheckpointInfo(const Base & base_, CheckpointIngestInfoPtr checkpoint_info_) - : base(base_) - , checkpoint_info(std::move(checkpoint_info_)) -{} - } // namespace DB diff --git a/dbms/src/Storages/KVStore/Decode/RegionTable.h b/dbms/src/Storages/KVStore/Decode/RegionTable.h index 22f554752ce..0b62fcb5185 100644 --- a/dbms/src/Storages/KVStore/Decode/RegionTable.h +++ b/dbms/src/Storages/KVStore/Decode/RegionTable.h @@ -18,7 +18,7 @@ #include #include #include -#include +#include #include #include #include @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include #include @@ -36,25 +36,8 @@ namespace DB { -struct MockRaftCommand; -struct ColumnsDescription; -class IStorage; -using StoragePtr = std::shared_ptr; -class TMTContext; -class IBlockInputStream; -using BlockInputStreamPtr = std::shared_ptr; -class Block; -// for debug -struct MockTiDBTable; class RegionRangeKeys; class RegionTaskLock; -struct RegionPtrWithSnapshotFiles; 
-class RegionScanFilter; -using RegionScanFilterPtr = std::shared_ptr; -struct CheckpointInfo; -using CheckpointInfoPtr = std::shared_ptr; -struct CheckpointIngestInfo; -using CheckpointIngestInfoPtr = std::shared_ptr; class RegionTable : private boost::noncopyable { @@ -97,7 +80,7 @@ class RegionTable : private boost::noncopyable // Most of the regions are scheduled to TiFlash by a raft snapshot. void addPrehandlingRegion(const Region & region); - // When a reigon is removed out of TiFlash. + // When a region is removed out of TiFlash. void removeRegion(RegionID region_id, bool remove_data, const RegionTaskLock &); // Used by apply snapshot. @@ -136,18 +119,14 @@ class RegionTable : private boost::noncopyable const LoggerPtr & log, bool lock_region = true); - /// Check transaction locks in region, and write committed data in it into storage engine if check passed. Otherwise throw an LockException. - /// The write logic is the same as #writeCommittedByRegion, with some extra checks about region version and conf_version. - using ResolveLocksAndWriteRegionRes = std::variant; - static ResolveLocksAndWriteRegionRes resolveLocksAndWriteRegion( - TMTContext & tmt, + /// Check region metas and get transaction locks in region. + static std::variant checkRegionAndGetLocks( const TableID table_id, const RegionPtr & region, const Timestamp start_ts, const std::unordered_set * bypass_lock_ts, RegionVersion region_version, - RegionVersion conf_version, - const LoggerPtr & log); + RegionVersion conf_version); void clear(); @@ -190,41 +169,4 @@ class RegionTable : private boost::noncopyable }; -// A wrap of RegionPtr, with snapshot files directory waiting to be ingested -struct RegionPtrWithSnapshotFiles -{ - using Base = RegionPtr; - - /// can accept const ref of RegionPtr without cache - RegionPtrWithSnapshotFiles(const Base & base_, std::vector && external_files_ = {}); - - /// to be compatible with usage as RegionPtr. 
- Base::element_type * operator->() const { return base.operator->(); } - const Base::element_type & operator*() const { return base.operator*(); } - - /// make it could be cast into RegionPtr implicitly. - operator const Base &() const { return base; } - - const Base & base; - const std::vector external_files; -}; - -// A wrap of RegionPtr, with checkpoint info to be ingested -struct RegionPtrWithCheckpointInfo -{ - using Base = RegionPtr; - - RegionPtrWithCheckpointInfo(const Base & base_, CheckpointIngestInfoPtr checkpoint_info_); - - /// to be compatible with usage as RegionPtr. - Base::element_type * operator->() const { return base.operator->(); } - const Base::element_type & operator*() const { return base.operator*(); } - - /// make it could be cast into RegionPtr implicitly. - operator const Base &() const { return base; } - - const Base & base; - CheckpointIngestInfoPtr checkpoint_info; -}; - } // namespace DB diff --git a/dbms/src/Storages/KVStore/FFI/ProxyFFI.cpp b/dbms/src/Storages/KVStore/FFI/ProxyFFI.cpp index 9b1be1b83b2..53409c652f5 100644 --- a/dbms/src/Storages/KVStore/FFI/ProxyFFI.cpp +++ b/dbms/src/Storages/KVStore/FFI/ProxyFFI.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include diff --git a/dbms/src/Storages/KVStore/KVStore.h b/dbms/src/Storages/KVStore/KVStore.h index f93f6be9947..310282a378a 100644 --- a/dbms/src/Storages/KVStore/KVStore.h +++ b/dbms/src/Storages/KVStore/KVStore.h @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -54,8 +55,6 @@ class KVStore; using KVStorePtr = std::shared_ptr; class RegionTable; -class Region; -using RegionPtr = std::shared_ptr; struct RaftCommandResult; class KVStoreTaskLock; @@ -77,8 +76,6 @@ class ReadIndexStressTest; struct FileUsageStatistics; class PathPool; class RegionPersister; -struct CheckpointInfo; -using CheckpointInfoPtr = std::shared_ptr; struct CheckpointIngestInfo; using CheckpointIngestInfoPtr = std::shared_ptr; class 
UniversalPageStorage; diff --git a/dbms/src/Storages/KVStore/MultiRaft/ApplySnapshot.cpp b/dbms/src/Storages/KVStore/MultiRaft/ApplySnapshot.cpp index f2fc5e547b3..3d013a119e9 100644 --- a/dbms/src/Storages/KVStore/MultiRaft/ApplySnapshot.cpp +++ b/dbms/src/Storages/KVStore/MultiRaft/ApplySnapshot.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -43,6 +44,18 @@ extern const int LOGICAL_ERROR; extern const int TABLE_IS_DROPPED; } // namespace ErrorCodes +RegionPtrWithSnapshotFiles::RegionPtrWithSnapshotFiles( + const Base & base_, + std::vector && external_files_) + : base(base_) + , external_files(std::move(external_files_)) +{} + +RegionPtrWithCheckpointInfo::RegionPtrWithCheckpointInfo(const Base & base_, CheckpointIngestInfoPtr checkpoint_info_) + : base(base_) + , checkpoint_info(std::move(checkpoint_info_)) +{} + template void KVStore::checkAndApplyPreHandledSnapshot(const RegionPtrWrap & new_region, TMTContext & tmt) { diff --git a/dbms/src/Storages/KVStore/MultiRaft/ApplySnapshot.h b/dbms/src/Storages/KVStore/MultiRaft/ApplySnapshot.h new file mode 100644 index 00000000000..d5948ec7a8d --- /dev/null +++ b/dbms/src/Storages/KVStore/MultiRaft/ApplySnapshot.h @@ -0,0 +1,65 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include + +namespace DB +{ + +struct CheckpointIngestInfo; +using CheckpointIngestInfoPtr = std::shared_ptr; + +// A wrap of RegionPtr, with snapshot files directory waiting to be ingested +struct RegionPtrWithSnapshotFiles +{ + using Base = RegionPtr; + + /// can accept const ref of RegionPtr without cache + RegionPtrWithSnapshotFiles( // NOLINT(google-explicit-constructor) + const Base & base_, + std::vector && external_files_ = {}); + + /// to be compatible with usage as RegionPtr. + Base::element_type * operator->() const { return base.operator->(); } + const Base::element_type & operator*() const { return base.operator*(); } + + /// make it could be cast into RegionPtr implicitly. + operator const Base &() const { return base; } // NOLINT(google-explicit-constructor) + + const Base & base; + const std::vector external_files; +}; + +// A wrap of RegionPtr, with checkpoint info to be ingested +struct RegionPtrWithCheckpointInfo +{ + using Base = RegionPtr; + + RegionPtrWithCheckpointInfo(const Base & base_, CheckpointIngestInfoPtr checkpoint_info_); + + /// to be compatible with usage as RegionPtr. + Base::element_type * operator->() const { return base.operator->(); } + const Base::element_type & operator*() const { return base.operator*(); } + + /// make it could be cast into RegionPtr implicitly. 
+ operator const Base &() const { return base; } // NOLINT(google-explicit-constructor) + + const Base & base; + CheckpointIngestInfoPtr checkpoint_info; +}; + +} // namespace DB diff --git a/dbms/src/Storages/KVStore/MultiRaft/Disagg/CheckpointIngestInfo.cpp b/dbms/src/Storages/KVStore/MultiRaft/Disagg/CheckpointIngestInfo.cpp index 8b708cd18fa..874f68534d6 100644 --- a/dbms/src/Storages/KVStore/MultiRaft/Disagg/CheckpointIngestInfo.cpp +++ b/dbms/src/Storages/KVStore/MultiRaft/Disagg/CheckpointIngestInfo.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include diff --git a/dbms/src/Storages/KVStore/MultiRaft/Disagg/CheckpointIngestInfo.h b/dbms/src/Storages/KVStore/MultiRaft/Disagg/CheckpointIngestInfo.h index 238f6621855..5b1cddad95f 100644 --- a/dbms/src/Storages/KVStore/MultiRaft/Disagg/CheckpointIngestInfo.h +++ b/dbms/src/Storages/KVStore/MultiRaft/Disagg/CheckpointIngestInfo.h @@ -17,14 +17,13 @@ #include #include #include +#include #include #include namespace DB { -class Region; -using RegionPtr = std::shared_ptr; class TMTContext; class UniversalPageStorage; using UniversalPageStoragePtr = std::shared_ptr; diff --git a/dbms/src/Storages/KVStore/MultiRaft/Disagg/FastAddPeer.h b/dbms/src/Storages/KVStore/MultiRaft/Disagg/FastAddPeer.h index a67adbd23d3..52aceb45675 100644 --- a/dbms/src/Storages/KVStore/MultiRaft/Disagg/FastAddPeer.h +++ b/dbms/src/Storages/KVStore/MultiRaft/Disagg/FastAddPeer.h @@ -15,14 +15,14 @@ #pragma once #include +#include #include +#include namespace DB { struct CheckpointInfo; using CheckpointInfoPtr = std::shared_ptr; -class Region; -using RegionPtr = std::shared_ptr; using CheckpointRegionInfoAndData = std::tuple; FastAddPeerRes genFastAddPeerRes( diff --git a/dbms/src/Storages/KVStore/MultiRaft/Disagg/FastAddPeerContext.h b/dbms/src/Storages/KVStore/MultiRaft/Disagg/FastAddPeerContext.h index 0ec62668cbc..bc4801e3bfb 100644 --- a/dbms/src/Storages/KVStore/MultiRaft/Disagg/FastAddPeerContext.h +++ 
b/dbms/src/Storages/KVStore/MultiRaft/Disagg/FastAddPeerContext.h @@ -18,6 +18,7 @@ #include #include #include +#include #include namespace DB @@ -26,8 +27,6 @@ class FastAddPeerContext; using FAPAsyncTasks = AsyncTasks, FastAddPeerRes>; struct CheckpointInfo; using CheckpointInfoPtr = std::shared_ptr; -class Region; -using RegionPtr = std::shared_ptr; using CheckpointRegionInfoAndData = std::tuple; diff --git a/dbms/src/Storages/KVStore/MultiRaft/PrehandleSnapshot.cpp b/dbms/src/Storages/KVStore/MultiRaft/PrehandleSnapshot.cpp index 767d5edca21..43e430c95be 100644 --- a/dbms/src/Storages/KVStore/MultiRaft/PrehandleSnapshot.cpp +++ b/dbms/src/Storages/KVStore/MultiRaft/PrehandleSnapshot.cpp @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include diff --git a/dbms/src/Storages/KVStore/MultiRaft/RegionData.h b/dbms/src/Storages/KVStore/MultiRaft/RegionData.h index 2ae3810520d..7c3083282d6 100644 --- a/dbms/src/Storages/KVStore/MultiRaft/RegionData.h +++ b/dbms/src/Storages/KVStore/MultiRaft/RegionData.h @@ -30,7 +30,6 @@ using DecodedLockCFValuePtr = std::shared_ptr +#include #include #include @@ -28,8 +29,6 @@ class LockInfo; namespace DB { -class Region; -using RegionPtr = std::shared_ptr; class RegionRangeKeys; using ImutRegionRangePtr = std::shared_ptr; @@ -48,7 +47,7 @@ struct RaftCommandResult : private boost::noncopyable bool sync_log; Type type = Type::Default; - std::vector split_regions{}; + Regions split_regions{}; ImutRegionRangePtr ori_region_range; RegionID source_region_id; }; diff --git a/dbms/src/Storages/KVStore/MultiRaft/RegionMeta.h b/dbms/src/Storages/KVStore/MultiRaft/RegionMeta.h index 36997b5a8ee..2a6a0f7c809 100644 --- a/dbms/src/Storages/KVStore/MultiRaft/RegionMeta.h +++ b/dbms/src/Storages/KVStore/MultiRaft/RegionMeta.h @@ -32,7 +32,6 @@ class RegionKVStoreOldTest; } // namespace tests struct RegionMergeResult; -class Region; class MetaRaftCommandDelegate; class RegionRaftCommandDelegate; @@ -62,6 +61,7 @@ struct 
RegionMetaSnapshot class RegionMeta { public: + // For deserialize from buffer RegionMeta( metapb::Peer peer_, raft_serverpb::RaftApplyState apply_state_, @@ -157,8 +157,6 @@ class RegionMeta const RegionID region_id; }; -// TODO: Integrate initialApplyState to MockTiKV - // When we create a region peer, we should initialize its log term/index > 0, // so that we can force the follower peer to sync the snapshot first. static constexpr UInt64 RAFT_INIT_LOG_TERM = 5; diff --git a/dbms/src/Storages/KVStore/MultiRaft/RegionPersister.h b/dbms/src/Storages/KVStore/MultiRaft/RegionPersister.h index 6dc695b2b0b..eb4c35c2281 100644 --- a/dbms/src/Storages/KVStore/MultiRaft/RegionPersister.h +++ b/dbms/src/Storages/KVStore/MultiRaft/RegionPersister.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -25,9 +26,6 @@ namespace DB { class PathPool; -class Region; -using RegionPtr = std::shared_ptr; -using RegionMap = std::unordered_map; class RegionTaskLock; struct RegionManager; diff --git a/dbms/src/Storages/KVStore/MultiRaft/RegionsRangeIndex.h b/dbms/src/Storages/KVStore/MultiRaft/RegionsRangeIndex.h index 15e5fba2615..37d34f61fd1 100644 --- a/dbms/src/Storages/KVStore/MultiRaft/RegionsRangeIndex.h +++ b/dbms/src/Storages/KVStore/MultiRaft/RegionsRangeIndex.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include @@ -28,10 +29,6 @@ class KVStoreTestBase; class RegionKVStoreOldTest; } // namespace tests -class Region; -using RegionPtr = std::shared_ptr; -using RegionMap = std::unordered_map; - struct TiKVRangeKey; using RegionRange = RegionRangeKeys::RegionRange; diff --git a/dbms/src/Storages/KVStore/MultiRaft/Spill/RegionUncommittedDataList.h b/dbms/src/Storages/KVStore/MultiRaft/Spill/RegionUncommittedDataList.h index a5ff515d2fe..b45ef539fbd 100644 --- a/dbms/src/Storages/KVStore/MultiRaft/Spill/RegionUncommittedDataList.h +++ b/dbms/src/Storages/KVStore/MultiRaft/Spill/RegionUncommittedDataList.h @@ -14,7 +14,6 @@ #pragma 
once -#include #include namespace DB @@ -70,4 +69,4 @@ struct RegionUncommittedDataList // Timestamp start_ts; }; -} // namespace DB \ No newline at end of file +} // namespace DB diff --git a/dbms/src/Storages/KVStore/ProxyStateMachine.h b/dbms/src/Storages/KVStore/ProxyStateMachine.h index 56525fc436b..1cb58372758 100644 --- a/dbms/src/Storages/KVStore/ProxyStateMachine.h +++ b/dbms/src/Storages/KVStore/ProxyStateMachine.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include #include diff --git a/dbms/src/Storages/KVStore/Read/LearnerRead.h b/dbms/src/Storages/KVStore/Read/LearnerRead.h index 51e832b7f6b..7bfa5c21c65 100644 --- a/dbms/src/Storages/KVStore/Read/LearnerRead.h +++ b/dbms/src/Storages/KVStore/Read/LearnerRead.h @@ -20,7 +20,6 @@ #include #include -#include namespace DB diff --git a/dbms/src/Storages/KVStore/Read/LearnerReadWorker.cpp b/dbms/src/Storages/KVStore/Read/LearnerReadWorker.cpp index 3a6f5f2c82f..0cab437cd00 100644 --- a/dbms/src/Storages/KVStore/Read/LearnerReadWorker.cpp +++ b/dbms/src/Storages/KVStore/Read/LearnerReadWorker.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -406,17 +407,14 @@ void LearnerReadWorker::waitIndex( continue; } - // Try to resolve locks and flush data into storage layer const auto & physical_table_id = region_to_query.physical_table_id; - auto res = RegionTable::resolveLocksAndWriteRegion( - tmt, + auto res = RegionTable::checkRegionAndGetLocks( physical_table_id, region, mvcc_query_info.start_ts, region_to_query.bypass_lock_ts, region_to_query.version, - region_to_query.conf_version, - log); + region_to_query.conf_version); std::visit( variant_op::overloaded{ diff --git a/dbms/src/Storages/KVStore/Region.cpp b/dbms/src/Storages/KVStore/Region.cpp index 5f1fe795060..2f3322adbdc 100644 --- a/dbms/src/Storages/KVStore/Region.cpp +++ b/dbms/src/Storages/KVStore/Region.cpp @@ -335,16 +335,6 @@ Region::~Region() GET_METRIC(tiflash_raft_classes_count, 
type_region).Decrement(); } -TableID Region::getMappedTableID() const -{ - return mapped_table_id; -} - -KeyspaceID Region::getKeyspaceID() const -{ - return keyspace_id; -} - void Region::setPeerState(raft_serverpb::PeerState state) { meta.setPeerState(state); @@ -372,7 +362,8 @@ std::pair Region::getApproxMemCacheInfo() const { return { approx_mem_cache_rows.load(std::memory_order_relaxed), - approx_mem_cache_bytes.load(std::memory_order_relaxed)}; + approx_mem_cache_bytes.load(std::memory_order_relaxed), + }; } void Region::cleanApproxMemCacheInfo() const @@ -409,7 +400,7 @@ void Region::setRegionTableCtx(RegionTableCtxPtr ctx) const void Region::maybeWarnMemoryLimitByTable(TMTContext & tmt, const char * from) { - // If there are data flow in, we will check if the memory is exhaused. + // If there are data flow in, we will check if the memory is exhausted. auto limit = tmt.getKVStore()->getKVStoreMemoryLimit(); size_t current = real_rss.load() > 0 ? real_rss.load() : 0; if unlikely (limit == 0 || current == 0) diff --git a/dbms/src/Storages/KVStore/Region.h b/dbms/src/Storages/KVStore/Region.h index 9760d1ad0ef..6d6b60b552e 100644 --- a/dbms/src/Storages/KVStore/Region.h +++ b/dbms/src/Storages/KVStore/Region.h @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -45,9 +46,6 @@ class RegionKVStoreOldTest; class RegionKVStoreTest; } // namespace tests -class Region; -using RegionPtr = std::shared_ptr; -using Regions = std::vector; struct RaftCommandResult; class KVStore; @@ -229,8 +227,9 @@ class Region : public std::enable_shared_from_this RegionVersion version() const; RegionVersion confVer() const; - TableID getMappedTableID() const; - KeyspaceID getKeyspaceID() const; + TableID getMappedTableID() const { return mapped_table_id; } + KeyspaceID getKeyspaceID() const { return keyspace_id; } + KeyspaceTableID getKeyspaceTableID() const { return KeyspaceTableID{keyspace_id, mapped_table_id}; } /// get approx rows, bytes info about mem cache. 
std::pair getApproxMemCacheInfo() const; diff --git a/dbms/src/Storages/KVStore/Region_fwd.h b/dbms/src/Storages/KVStore/Region_fwd.h new file mode 100644 index 00000000000..cabdf4d2837 --- /dev/null +++ b/dbms/src/Storages/KVStore/Region_fwd.h @@ -0,0 +1,29 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include +#include +#include + +namespace DB +{ +class Region; +using RegionPtr = std::shared_ptr; +using Regions = std::vector; +using RegionMap = std::unordered_map; +} // namespace DB diff --git a/dbms/src/Storages/KVStore/TMTContext.cpp b/dbms/src/Storages/KVStore/TMTContext.cpp index 9568acfcf97..7428fad145e 100644 --- a/dbms/src/Storages/KVStore/TMTContext.cpp +++ b/dbms/src/Storages/KVStore/TMTContext.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include diff --git a/dbms/src/Storages/KVStore/tests/gtest_kvstore.cpp b/dbms/src/Storages/KVStore/tests/gtest_kvstore.cpp index 2f0b07d5c37..cac84fb5800 100644 --- a/dbms/src/Storages/KVStore/tests/gtest_kvstore.cpp +++ b/dbms/src/Storages/KVStore/tests/gtest_kvstore.cpp @@ -13,7 +13,9 @@ // limitations under the License. 
#include +#include #include +#include #include #include @@ -22,6 +24,8 @@ namespace DB::tests { +using namespace RegionBench; + class RegionKVStoreOldTest : public KVStoreTestBase { public: @@ -578,11 +582,7 @@ void RegionKVStoreOldTest::testRaftMerge(Context & ctx, KVStore & kvs, TMTContex // add 7 back auto task_lock = kvs.genTaskLock(); auto lock = kvs.genRegionMgrWriteLock(task_lock); - auto region = makeRegion( - source_region_id, - RecordKVFormat::genKey(table_id, 0), - RecordKVFormat::genKey(table_id, 5), - kvs.getProxyHelper()); + auto region = makeRegionForTable(source_region_id, table_id, 0, 5, kvs.getProxyHelper()); lock.regions.emplace(source_region_id, region); lock.index.add(region); } @@ -626,14 +626,6 @@ TEST_F(RegionKVStoreOldTest, RegionReadWrite) std::make_optional( std::make_pair(RecordKVFormat::genKey(table_id, 0), RecordKVFormat::genKey(table_id, 1000)))); auto region = kvs.getRegion(region_id); - { - // Test create RegionMeta. - auto meta = RegionMeta( - createPeer(2, true), - createRegionInfo(666, RecordKVFormat::genKey(0, 0), RecordKVFormat::genKey(0, 1000)), - initialApplyState()); - ASSERT_EQ(meta.peerId(), 2); - } { // Test GenRegionReadIndexReq. ASSERT_TRUE(region->checkIndex(5)); @@ -1108,11 +1100,7 @@ try }); // Initially region_19 range is [0, 10000) { - auto region = makeRegion( - region_id, - RecordKVFormat::genKey(table_id, 0), - RecordKVFormat::genKey(table_id, 10000), - kvs.getProxyHelper()); + auto region = makeRegionForTable(region_id, table_id, 0, 10000, kvs.getProxyHelper()); // Fill data from 20 to 100. 
GenMockSSTData(DMTestEnv::getMinimalTableInfo(table_id), table_id, region_id_str, 20, 100, 0); std::vector sst_views{ @@ -1150,11 +1138,7 @@ try } // Later, its range is changed to [20000, 50000) { - auto region = makeRegion( - region_id, - RecordKVFormat::genKey(table_id, 20000), - RecordKVFormat::genKey(table_id, 50000), - kvs.getProxyHelper()); + auto region = makeRegionForTable(region_id, table_id, 20000, 50000, kvs.getProxyHelper()); // Fill data from 20100 to 20200. GenMockSSTData(DMTestEnv::getMinimalTableInfo(table_id), table_id, region_id_str, 20100, 20200, 0); std::vector sst_views{ @@ -1239,11 +1223,7 @@ try TableID table_id = 1; auto region_id = 19; - auto region = makeRegion( - region_id, - RecordKVFormat::genKey(table_id, 50), - RecordKVFormat::genKey(table_id, 60), - kvs.getProxyHelper()); + auto region = makeRegionForTable(region_id, table_id, 50, 60, kvs.getProxyHelper()); { // Prepare a region with some kvs @@ -1336,11 +1316,7 @@ try // Snapshot will be rejected if region overlaps with existing Region. 
{ // create an empty region 22, range=[50,100) - auto region = makeRegion( - 22, - RecordKVFormat::genKey(table_id, 50), - RecordKVFormat::genKey(table_id, 100), - kvs.getProxyHelper()); + auto region = makeRegionForTable(22, table_id, 50, 100, kvs.getProxyHelper()); auto prehandle_result = kvs.preHandleSnapshotToFiles(region, {}, 9, 5, std::nullopt, ctx.getTMTContext()); kvs.checkAndApplyPreHandledSnapshot( RegionPtrWithSnapshotFiles{region, std::move(prehandle_result.ingest_ids)}, @@ -1353,11 +1329,7 @@ try try { // try apply snapshot to region 20, range=[50, 100) that is overlapped with region 22, should be rejected - auto region = makeRegion( - 20, - RecordKVFormat::genKey(table_id, 50), - RecordKVFormat::genKey(table_id, 100), - kvs.getProxyHelper()); + auto region = makeRegionForTable(20, table_id, 50, 100, kvs.getProxyHelper()); auto prehandle_result = kvs.preHandleSnapshotToFiles(region, {}, 9, 5, std::nullopt, ctx.getTMTContext()); kvs.checkAndApplyPreHandledSnapshot( RegionPtrWithSnapshotFiles{region, std::move(prehandle_result.ingest_ids)}, @@ -1377,11 +1349,7 @@ try try { - auto region = makeRegion( - 20, - RecordKVFormat::genKey(table_id, 50), - RecordKVFormat::genKey(table_id, 100), - kvs.getProxyHelper()); + auto region = makeRegionForTable(20, table_id, 50, 100, kvs.getProxyHelper()); // preHandleSnapshotToFiles will assert proxy_ptr is not null. 
auto prehandle_result = kvs.preHandleSnapshotToFiles(region, {}, 10, 5, std::nullopt, ctx.getTMTContext()); proxy_helper->proxy_ptr.inner = nullptr; @@ -1403,11 +1371,7 @@ try s.set_state(::raft_serverpb::PeerState::Tombstone); s; })); - auto region = makeRegion( - 20, - RecordKVFormat::genKey(table_id, 50), - RecordKVFormat::genKey(table_id, 100), - kvs.getProxyHelper()); + auto region = makeRegionForTable(20, table_id, 50, 100, kvs.getProxyHelper()); auto prehandle_result = kvs.preHandleSnapshotToFiles(region, {}, 10, 5, std::nullopt, ctx.getTMTContext()); kvs.checkAndApplyPreHandledSnapshot( RegionPtrWithSnapshotFiles{region, std::move(prehandle_result.ingest_ids)}, @@ -1443,8 +1407,7 @@ try auto region_id_str = std::to_string(region_id); // Prepare a region with some kvs { - auto region - = makeRegion(region_id, RecordKVFormat::genKey(1, 50), RecordKVFormat::genKey(1, 60), kvs.getProxyHelper()); + auto region = makeRegionForTable(region_id, 1, 50, 60, kvs.getProxyHelper()); auto & mmp = MockSSTReader::getMockSSTData(); MockSSTReader::getMockSSTData().clear(); MockSSTReader::Data default_kv_list; @@ -1542,19 +1505,19 @@ TEST_F(RegionKVStoreOldTest, RegionRange) const auto & root_map = region_index.getRoot(); ASSERT_EQ(root_map.size(), 2); // start and end all equals empty - region_index.add(makeRegion(1, RecordKVFormat::genKey(1, 0), RecordKVFormat::genKey(1, 10))); + region_index.add(makeRegionForTable(1, 1, 0, 10)); ASSERT_EQ(root_map.begin()->second.region_map.size(), 0); - region_index.add(makeRegion(2, RecordKVFormat::genKey(1, 0), RecordKVFormat::genKey(1, 3))); - region_index.add(makeRegion(3, RecordKVFormat::genKey(1, 0), RecordKVFormat::genKey(1, 1))); + region_index.add(makeRegionForTable(2, 1, 0, 3)); + region_index.add(makeRegionForTable(3, 1, 0, 1)); auto res = region_index.findByRangeOverlap(RegionRangeKeys::makeComparableKeys(TiKVKey(""), TiKVKey(""))); ASSERT_EQ(res.size(), 3); auto res2 = 
region_index.findByRangeChecked(RegionRangeKeys::makeComparableKeys(TiKVKey(""), TiKVKey(""))); ASSERT_TRUE(std::holds_alternative(res2)); - region_index.add(makeRegion(4, RecordKVFormat::genKey(1, 1), RecordKVFormat::genKey(1, 4))); + region_index.add(makeRegionForTable(4, 1, 1, 4)); // -inf,0,1,3,4,10,inf ASSERT_EQ(root_map.size(), 7); @@ -1620,7 +1583,7 @@ TEST_F(RegionKVStoreOldTest, RegionRange) ASSERT_TRUE(std::regex_match(res, msg_reg)); } - region_index.add(makeRegion(2, RecordKVFormat::genKey(1, 3), RecordKVFormat::genKey(1, 5))); + region_index.add(makeRegionForTable(2, 1, 3, 5)); try { region_index.remove( @@ -1679,7 +1642,7 @@ TEST_F(RegionKVStoreOldTest, RegionRange) try { - region_index.add(makeRegion(6, RecordKVFormat::genKey(6, 6), RecordKVFormat::genKey(6, 6))); + region_index.add(makeRegionForTable(6, 6, 6, 6)); assert(false); } catch (Exception & e) @@ -1692,9 +1655,9 @@ TEST_F(RegionKVStoreOldTest, RegionRange) region_index.clear(); - region_index.add(makeRegion(1, RecordKVFormat::genKey(1, 0), RecordKVFormat::genKey(1, 1))); - region_index.add(makeRegion(2, RecordKVFormat::genKey(1, 1), RecordKVFormat::genKey(1, 2))); - region_index.add(makeRegion(3, RecordKVFormat::genKey(1, 2), RecordKVFormat::genKey(1, 3))); + region_index.add(makeRegionForTable(1, 1, 0, 1)); + region_index.add(makeRegionForTable(2, 1, 1, 2)); + region_index.add(makeRegionForTable(3, 1, 2, 3)); ASSERT_EQ(root_map.size(), 6); region_index.remove( @@ -1714,30 +1677,31 @@ TEST_F(RegionKVStoreOldTest, RegionRange) } // Test region range with merge. { + using RegionBench::createMetaRegionCommonHandle; { // Compute `source_at_left` by region range. 
ASSERT_EQ( MetaRaftCommandDelegate::computeRegionMergeResult( - createRegionInfo(1, "x", ""), - createRegionInfo(1000, "", "x")) + createMetaRegionCommonHandle(1, "x", ""), + createMetaRegionCommonHandle(1000, "", "x")) .source_at_left, false); ASSERT_EQ( MetaRaftCommandDelegate::computeRegionMergeResult( - createRegionInfo(1, "", "x"), - createRegionInfo(1000, "x", "")) + createMetaRegionCommonHandle(1, "", "x"), + createMetaRegionCommonHandle(1000, "x", "")) .source_at_left, true); ASSERT_EQ( MetaRaftCommandDelegate::computeRegionMergeResult( - createRegionInfo(1, "x", "y"), - createRegionInfo(1000, "y", "z")) + createMetaRegionCommonHandle(1, "x", "y"), + createMetaRegionCommonHandle(1000, "y", "z")) .source_at_left, true); ASSERT_EQ( MetaRaftCommandDelegate::computeRegionMergeResult( - createRegionInfo(1, "y", "z"), - createRegionInfo(1000, "x", "y")) + createMetaRegionCommonHandle(1, "y", "z"), + createMetaRegionCommonHandle(1000, "x", "y")) .source_at_left, false); } diff --git a/dbms/src/Storages/KVStore/tests/gtest_kvstore_fast_add_peer.cpp b/dbms/src/Storages/KVStore/tests/gtest_kvstore_fast_add_peer.cpp index 08b5c9ed30b..a6b465e089f 100644 --- a/dbms/src/Storages/KVStore/tests/gtest_kvstore_fast_add_peer.cpp +++ b/dbms/src/Storages/KVStore/tests/gtest_kvstore_fast_add_peer.cpp @@ -409,7 +409,8 @@ std::vector RegionKVStoreTestFAP::prepareForRestart kvs.debugGetConfigMut().debugSetCompactLogConfig(0, 0, 0, 0); if (opt.mock_add_new_peer) { - *kvs.getRegion(id)->mutMeta().debugMutRegionState().getMutRegion().add_peers() = createPeer(peer_id, true); + *kvs.getRegion(id)->mutMeta().debugMutRegionState().getMutRegion().add_peers() + = RegionBench::createPeer(peer_id, true); proxy_instance->getRegion(id)->addPeer(store_id, peer_id, metapb::PeerRole::Learner); } persistAfterWrite(global_context, kvs, proxy_instance, page_storage, id, index); diff --git a/dbms/src/Storages/KVStore/tests/gtest_learner_read.cpp 
b/dbms/src/Storages/KVStore/tests/gtest_learner_read.cpp index d2077ff541f..9baf09cbe40 100644 --- a/dbms/src/Storages/KVStore/tests/gtest_learner_read.cpp +++ b/dbms/src/Storages/KVStore/tests/gtest_learner_read.cpp @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include #include #include -#include #include #include #include @@ -76,34 +76,23 @@ try const TableID table_id = 100; + using RegionBench::makeRegionForTable; LearnerReadSnapshot snapshot{ { region_id_200, - RegionLearnerReadSnapshot(makeRegion( - region_id_200, - RecordKVFormat::genKey(table_id, 0), - RecordKVFormat::genKey(table_id, 10000))), + RegionLearnerReadSnapshot(makeRegionForTable(region_id_200, table_id, 0, 10000)), }, { region_id_201, - RegionLearnerReadSnapshot(makeRegion( - region_id_201, - RecordKVFormat::genKey(table_id, 10000), - RecordKVFormat::genKey(table_id, 20000))), + RegionLearnerReadSnapshot(makeRegionForTable(region_id_201, table_id, 10000, 20000)), }, { region_id_202, - RegionLearnerReadSnapshot(makeRegion( - region_id_202, - RecordKVFormat::genKey(table_id, 20000), - RecordKVFormat::genKey(table_id, 30000))), + RegionLearnerReadSnapshot(makeRegionForTable(region_id_202, table_id, 20000, 30000)), }, { region_id_203, - RegionLearnerReadSnapshot(makeRegion( - region_id_203, - RecordKVFormat::genKey(table_id, 30000), - RecordKVFormat::genKey(table_id, 40000))), + RegionLearnerReadSnapshot(makeRegionForTable(region_id_203, table_id, 30000, 40000)), }, }; diff --git a/dbms/src/Storages/KVStore/tests/gtest_memory.cpp b/dbms/src/Storages/KVStore/tests/gtest_memory.cpp index 3e395dfec25..772b75ceb1d 100644 --- a/dbms/src/Storages/KVStore/tests/gtest_memory.cpp +++ b/dbms/src/Storages/KVStore/tests/gtest_memory.cpp @@ -12,12 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include #include #include +#include #include +#include #include +#include +#include #include #include +#include #include #include #include // Included for `USE_JEMALLOC` @@ -106,7 +112,7 @@ try ASSERT_EQ(kvs.debug_memory_limit_warning_count, 1); } { - // lock with largetxn + // lock with large txn root_of_kvstore_mem_trackers->reset(); RegionID region_id = 4300; auto [start, end] = getStartEnd(region_id); @@ -118,7 +124,7 @@ try auto kvr2 = kvs.getRegion(4200); auto kvr3 = kvs.getRegion(region_id); ASSERT_NE(kvr3, nullptr); - std::string shor_value = "value"; + std::string short_value = "value"; auto lock_for_update_ts = 7777, txn_size = 1; const std::vector & async_commit = {"s1", "s2"}; const std::vector & rollback = {3, 4}; @@ -127,7 +133,7 @@ try "primary key", 421321, std::numeric_limits::max(), - &shor_value, + &short_value, 66666, lock_for_update_ts, txn_size, @@ -151,7 +157,7 @@ try // insert & remove root_of_kvstore_mem_trackers->reset(); RegionID region_id = 5000; - auto originTableSize = region_table.getTableRegionSize(NullspaceID, table_id); + auto origin_table_size = region_table.getTableRegionSize(NullspaceID, table_id); auto [start, end] = getStartEnd(region_id); auto str_key = pickKey(region_id, 1); auto [str_val_write, str_val_default] = pickWriteDefault(region_id, 1); @@ -161,12 +167,12 @@ try region->insertFromSnap(tmt, "default", TiKVKey::copyFrom(str_key), TiKVValue::copyFrom(str_val_default)); auto delta = str_key.dataSize() + str_val_default.size(); ASSERT_EQ(root_of_kvstore_mem_trackers->get(), delta); - ASSERT_EQ(region_table.getTableRegionSize(NullspaceID, table_id), originTableSize + delta); + ASSERT_EQ(region_table.getTableRegionSize(NullspaceID, table_id), origin_table_size + delta); region->removeDebug("default", TiKVKey::copyFrom(str_key)); ASSERT_EQ(root_of_kvstore_mem_trackers->get(), 0); ASSERT_EQ(region->dataSize(), root_of_kvstore_mem_trackers->get()); ASSERT_EQ(region->dataSize(), region->getData().totalSize()); - 
ASSERT_EQ(region_table.getTableRegionSize(NullspaceID, table_id), originTableSize); + ASSERT_EQ(region_table.getTableRegionSize(NullspaceID, table_id), origin_table_size); ASSERT_EQ(kvs.debug_memory_limit_warning_count, 1); } ASSERT_EQ(root_of_kvstore_mem_trackers->get(), 0); @@ -274,11 +280,8 @@ try auto new_region = splitRegion( region, RegionMeta( - createPeer(region_id + 1, true), - createRegionInfo( - region_id2, - RecordKVFormat::genKey(table_id, 12050), - RecordKVFormat::genKey(table_id, 12099)), + RegionBench::createPeer(region_id + 1, true), + RegionBench::createMetaRegion(region_id2, table_id, 12050, 12099), initialApplyState())); ASSERT_EQ(original_size, region_table.getTableRegionSize(NullspaceID, table_id)); ASSERT_EQ(root_of_kvstore_mem_trackers->get(), expected); @@ -321,11 +324,8 @@ try auto new_region = splitRegion( region, RegionMeta( - createPeer(region_id + 1, true), - createRegionInfo( - region_id2, - RecordKVFormat::genKey(table_id, 13150), - RecordKVFormat::genKey(table_id, 13199)), + RegionBench::createPeer(region_id + 1, true), + RegionBench::createMetaRegion(region_id2, table_id, 13150, 13199), initialApplyState())); ASSERT_EQ(original_size, region_table.getTableRegionSize(NullspaceID, table_id)); ASSERT_EQ(root_of_kvstore_mem_trackers->get(), expected); @@ -536,6 +536,169 @@ try CATCH +std::tuple genPreHandlingRegion( + KVStore & kvs, + RegionID region_id, + TableID table_id, + HandleID start, + HandleID end, + RegionTable & region_table, + TMTContext & tmt) +{ + UInt64 peer_id = 100000 + region_id; // gen a fake peer_id + auto meta = RegionBench::createMetaRegion( + region_id, + table_id, + start, + end, + /*maybe_epoch=*/std::nullopt, + /*maybe_peers=*/std::vector{RegionBench::createPeer(peer_id, true)}); + auto new_region = kvs.genRegionPtr( + std::move(meta), + peer_id, + /*index*/ RAFT_INIT_LOG_INDEX, + /*term*/ RAFT_INIT_LOG_TERM, + region_table); + // empty snapshot + SSTViewVec snaps{.views = nullptr, .len = 0}; + auto 
prehandle_result = kvs.preHandleSnapshotToFiles(new_region, snaps, 20, RAFT_INIT_LOG_TERM, std::nullopt, tmt); + return {new_region, prehandle_result}; +} + +// A unit test that Region being removed/created concurrently +TEST_F(RegionKVStoreTest, RegionTableBeingRecreated) +try +{ + LoggerPtr log = Logger::get(); + auto & ctx = TiFlashTestEnv::getGlobalContext(); + auto & tmt = ctx.getTMTContext(); + initStorages(); + KVStore & kvs = getKVS(); + ctx.getTMTContext().debugSetKVStore(kvstore); + const auto table_id = proxy_instance->bootstrapTable(ctx, kvs, ctx.getTMTContext()); + + auto & region_table = ctx.getTMTContext().getRegionTable(); + auto get_start_end = [&](Int64 range_beg, Int64 range_end) { + return std::make_pair(RecordKVFormat::genKey(table_id, range_beg), RecordKVFormat::genKey(table_id, range_end)); + }; + + const RegionID region_id_0 = 7000; + { + // Step 1: Generate a Region with region_id = 7000, range[100, 200) + LOG_INFO(log, "Step 1"); + root_of_kvstore_mem_trackers->reset(); + region_table.debugClearTableRegionSize(NullspaceID, table_id); + auto [start, end] = get_start_end(100, 200); + proxy_instance->debugAddRegions(kvs, tmt, {region_id_0}, {{start, end}}); + + root_of_kvstore_mem_trackers->reset(); + + RegionPtr region = kvs.getRegion(region_id_0); + auto str_key = RecordKVFormat::genKey(table_id, 105, 111); + auto [str_val_write, str_val_default] = proxy_instance->generateTiKVKeyValue(111, 999); + auto str_lock_value + = RecordKVFormat::encodeLockCfValue(RecordKVFormat::CFModifyFlag::PutFlag, "PK", region_id_0, 999) + .toString(); + region->insertDebug("default", TiKVKey::copyFrom(str_key), TiKVValue::copyFrom(str_val_default)); + ASSERT_EQ(root_of_kvstore_mem_trackers->get(), str_key.dataSize() + str_val_default.size()); + tryPersistRegion(kvs, region_id_0); + LOG_INFO(log, "Step 1: table_size: {}", region->getRegionTableSize()); + } + + RegionID pre_handle_region_id_1 = 7001; + RegionID pre_handle_region_id_2 = 7002; + { + // Step 2: Mock 
that Region with region_id = 7001, range[200, 300) being pre-handle + LOG_INFO(log, "Step 2"); + auto && [new_region_1, pre_handle_res_1] + = genPreHandlingRegion(kvs, pre_handle_region_id_1, table_id, 200, 300, region_table, tmt); + ASSERT_NE(new_region_1, nullptr); + + { + // check that Region 7000 and pre-handle Region 7001 should share the same table_ctx + auto region_0 = kvs.getRegion(region_id_0); + ASSERT_NE(region_0, nullptr); + ASSERT_EQ(region_0->getRegionTableSize(), new_region_1->getRegionTableSize()); + ASSERT_EQ(region_0->getRegionTableCtx().get(), new_region_1->getRegionTableCtx().get()); + LOG_INFO(log, "Step 2: table_size: {}", region_0->getRegionTableSize()); + } + + { + // Step 3: Mock that Region with region_id = 7000 is removed + // (This could lead to the RegionTable::Table instance being removed + // because no known Region belong to the table_id). + LOG_INFO(log, "Step 3"); + kvs.removeRegion( + region_id_0, + /*remove_data*/ true, + region_table, + kvs.genTaskLock(), + kvs.region_manager.genRegionTaskLock(region_id_0)); + } + + // Step 4: Mock that Region with region_id = 7002, range[300, 400) being pre-handle + // after region 7000 is removed. 
+ LOG_INFO(log, "Step 4"); + auto && [new_region_2, pre_handle_res_2] + = genPreHandlingRegion(kvs, pre_handle_region_id_2, table_id, 300, 400, region_table, tmt); + ASSERT_NE(new_region_2, nullptr); + // Step 5: apply the pre-handle-region-1 and pre-handle-region-2 + kvs.applyPreHandledSnapshot( + RegionPtrWithSnapshotFiles{new_region_2, std::move(pre_handle_res_2.ingest_ids)}, + tmt); + kvs.applyPreHandledSnapshot( + RegionPtrWithSnapshotFiles{new_region_1, std::move(pre_handle_res_1.ingest_ids)}, + tmt); + { + // check that Region 7001 and Region 7002 should share the same table_ctx + ASSERT_EQ(new_region_2->getRegionTableSize(), new_region_1->getRegionTableSize()); + ASSERT_EQ(new_region_2->getRegionTableCtx().get(), new_region_1->getRegionTableCtx().get()); + LOG_INFO(log, "Step 5: table_size: {}", new_region_2->getRegionTableSize()); + } + } + RegionID region_id_3 = 7003; + { + // Step 6: Mock that new Region is added after all + LOG_INFO(log, "Step 5"); + auto [start, end] = get_start_end(300, 400); + proxy_instance->debugAddRegions(kvs, tmt, {region_id_3}, {{start, end}}); + + { + // check that Region 7001, 7002, 7003 should share the same table_ctx + auto pre_handle_region_1 = kvs.getRegion(pre_handle_region_id_1); + auto pre_handle_region_2 = kvs.getRegion(pre_handle_region_id_2); + ASSERT_EQ(pre_handle_region_1->getRegionTableSize(), pre_handle_region_2->getRegionTableSize()); + ASSERT_EQ(pre_handle_region_1->getRegionTableCtx().get(), pre_handle_region_2->getRegionTableCtx().get()); + auto region_3 = kvs.getRegion(region_id_3); + ASSERT_EQ(pre_handle_region_1->getRegionTableSize(), region_3->getRegionTableSize()); + ASSERT_EQ(pre_handle_region_1->getRegionTableCtx().get(), region_3->getRegionTableCtx().get()); + LOG_INFO(log, "Step 6: table_size: {}", region_3->getRegionTableSize()); + } + } + { + // Step 7: If we insert some data into Region 7003, the size of table_ctx of all Region should be updated + RegionPtr region = kvs.getRegion(region_id_3); + 
auto str_key = RecordKVFormat::genKey(table_id, 105, 120); + auto [str_val_write, str_val_default] = proxy_instance->generateTiKVKeyValue(120, 999); + auto str_lock_value + = RecordKVFormat::encodeLockCfValue(RecordKVFormat::CFModifyFlag::PutFlag, "PK", region_id_0, 999) + .toString(); + region->insertDebug("default", TiKVKey::copyFrom(str_key), TiKVValue::copyFrom(str_val_default)); + LOG_INFO(log, "Step 7: table_size: {}", region->getRegionTableSize()); + ASSERT_EQ(root_of_kvstore_mem_trackers->get(), str_key.dataSize() + str_val_default.size()); + { + // check that Region 7001, 7002, 7003 should share the same table_ctx + auto region_1 = kvs.getRegion(pre_handle_region_id_1); + auto region_2 = kvs.getRegion(pre_handle_region_id_2); + auto region_3 = kvs.getRegion(region_id_3); + auto region_tbl_size = region_1->getRegionTableSize(); + ASSERT_EQ(region_tbl_size, region_2->getRegionTableSize()); + ASSERT_EQ(region_tbl_size, region_3->getRegionTableSize()); + } + } +} +CATCH + #if USE_JEMALLOC // following tests depends on jemalloc TEST(FFIJemallocTest, JemallocThread) try diff --git a/dbms/src/Storages/KVStore/tests/gtest_new_kvstore.cpp b/dbms/src/Storages/KVStore/tests/gtest_new_kvstore.cpp index 19a2d7c53f9..09760f98317 100644 --- a/dbms/src/Storages/KVStore/tests/gtest_new_kvstore.cpp +++ b/dbms/src/Storages/KVStore/tests/gtest_new_kvstore.cpp @@ -76,13 +76,13 @@ TEST_F(RegionKVStoreTest, KVStoreFailRecovery) try { auto & ctx = TiFlashTestEnv::getGlobalContext(); - KVStore & kvs = getKVS(); { auto applied_index = 0; auto region_id = 1; { MockRaftStoreProxy::FailCond cond; + KVStore & kvs = getKVS(); proxy_instance->debugAddRegions( kvs, ctx.getTMTContext(), @@ -118,6 +118,7 @@ try auto applied_index = 0; auto region_id = 2; { + KVStore & kvs = getKVS(); proxy_instance->debugAddRegions( kvs, ctx.getTMTContext(), @@ -168,6 +169,7 @@ try auto applied_index = 0; auto region_id = 3; { + KVStore & kvs = getKVS(); proxy_instance->debugAddRegions( kvs, 
ctx.getTMTContext(), @@ -209,6 +211,7 @@ try auto applied_index = 0; auto region_id = 4; { + KVStore & kvs = getKVS(); proxy_instance->debugAddRegions( kvs, ctx.getTMTContext(), @@ -257,37 +260,32 @@ TEST_F(RegionKVStoreTest, KVStoreInvalidWrites) try { auto & ctx = TiFlashTestEnv::getGlobalContext(); - { - auto region_id = 1; - { - initStorages(); - KVStore & kvs = getKVS(); - proxy_instance->bootstrapTable(ctx, kvs, ctx.getTMTContext()); - proxy_instance->bootstrapWithRegion(kvs, ctx.getTMTContext(), region_id, std::nullopt); + auto region_id = 1; + initStorages(); + KVStore & kvs = getKVS(); + proxy_instance->bootstrapTable(ctx, kvs, ctx.getTMTContext()); + proxy_instance->bootstrapWithRegion(kvs, ctx.getTMTContext(), region_id, std::nullopt); - MockRaftStoreProxy::FailCond cond; + MockRaftStoreProxy::FailCond cond; - auto kvr1 = kvs.getRegion(region_id); - auto r1 = proxy_instance->getRegion(region_id); - ASSERT_NE(r1, nullptr); - ASSERT_NE(kvr1, nullptr); - ASSERT_EQ(r1->getLatestAppliedIndex(), kvr1->appliedIndex()); - { - r1->getLatestAppliedIndex(); - // This key has empty PK which is actually truncated. - std::string k = "7480000000000001FFBD5F720000000000FAF9ECEFDC3207FFFC"; - std::string v = "4486809092ACFEC38906"; - auto str_key = Redact::hexStringToKey(k.data(), k.size()); - auto str_val = Redact::hexStringToKey(v.data(), v.size()); - - auto [index, term] - = proxy_instance - ->rawWrite(region_id, {str_key}, {str_val}, {WriteCmdType::Put}, {ColumnFamilyType::Write}); - EXPECT_THROW(proxy_instance->doApply(kvs, ctx.getTMTContext(), cond, region_id, index), Exception); - UNUSED(term); - EXPECT_THROW(ReadRegionCommitCache(kvr1, true), Exception); - } - } + auto kvr1 = kvs.getRegion(region_id); + auto r1 = proxy_instance->getRegion(region_id); + ASSERT_NE(r1, nullptr); + ASSERT_NE(kvr1, nullptr); + ASSERT_EQ(r1->getLatestAppliedIndex(), kvr1->appliedIndex()); + { + r1->getLatestAppliedIndex(); + // This key has empty PK which is actually truncated. 
+ std::string k = "7480000000000001FFBD5F720000000000FAF9ECEFDC3207FFFC"; + std::string v = "4486809092ACFEC38906"; + auto str_key = Redact::hexStringToKey(k.data(), k.size()); + auto str_val = Redact::hexStringToKey(v.data(), v.size()); + + auto [index, term] + = proxy_instance->rawWrite(region_id, {str_key}, {str_val}, {WriteCmdType::Put}, {ColumnFamilyType::Write}); + EXPECT_THROW(proxy_instance->doApply(kvs, ctx.getTMTContext(), cond, region_id, index), Exception); + UNUSED(term); + EXPECT_THROW(ReadRegionCommitCache(kvr1, true), Exception); } } CATCH diff --git a/dbms/src/Storages/KVStore/tests/gtest_proxy_state_machine.cpp b/dbms/src/Storages/KVStore/tests/gtest_proxy_state_machine.cpp index d24599c7394..7f44eb68f7b 100644 --- a/dbms/src/Storages/KVStore/tests/gtest_proxy_state_machine.cpp +++ b/dbms/src/Storages/KVStore/tests/gtest_proxy_state_machine.cpp @@ -17,6 +17,8 @@ #include #include +#include + // TODO: Move ServerInfo into KVStore, to make it more conhensive. namespace DB { diff --git a/dbms/src/Storages/KVStore/tests/gtest_region_block_reader.cpp b/dbms/src/Storages/KVStore/tests/gtest_region_block_reader.cpp index 174b19866bc..2c6f67f67ce 100644 --- a/dbms/src/Storages/KVStore/tests/gtest_region_block_reader.cpp +++ b/dbms/src/Storages/KVStore/tests/gtest_region_block_reader.cpp @@ -14,6 +14,7 @@ #include #include +#include #include #include #include @@ -22,7 +23,6 @@ #include #include #include -#include #include #include #include @@ -645,9 +645,10 @@ try NullspaceID); RegionID region_id = 4; + // the start_key and end_key for table_id = 68 String region_start_key(bytesFromHexString("7480000000000000FF445F720000000000FA")); String region_end_key(bytesFromHexString("7480000000000000FF4500000000000000F8")); - auto region = makeRegion(region_id, region_start_key, region_end_key); + auto region = RegionBench::makeRegionForRange(region_id, region_start_key, region_end_key); // the hex kv dump from SSTFile std::vector> kvs = { 
{"7480000000000000FF4D5F728000000000FF0000010000000000FAF9F3125EFCF3FFFE", "4C8280809290B4BB8606"}, diff --git a/dbms/src/Storages/KVStore/tests/gtest_region_persister.cpp b/dbms/src/Storages/KVStore/tests/gtest_region_persister.cpp index ace2d3c5bcf..59698f77dd5 100644 --- a/dbms/src/Storages/KVStore/tests/gtest_region_persister.cpp +++ b/dbms/src/Storages/KVStore/tests/gtest_region_persister.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -25,7 +26,7 @@ #include #include #include -#include +#include #include #include #include @@ -70,10 +71,19 @@ static ::testing::AssertionResult RegionCompare( } #define ASSERT_REGION_EQ(val1, val2) ASSERT_PRED_FORMAT2(::DB::tests::RegionCompare, val1, val2) -static RegionPtr makeTmpRegion() + +using namespace RegionBench; +namespace +{ +RegionMeta createRegionMeta(UInt64 id, DB::TableID table_id) { - return makeRegion(createRegionMeta(1001, 1)); + auto meta = RegionBench::createMetaRegion(id, table_id, 0, 300); + return RegionMeta( + /*peer=*/RegionBench::createPeer(31, true), + /*region=*/meta, + /*apply_state_=*/initialApplyState()); } +} // namespace static std::function mockSerFactory(int value) { @@ -140,13 +150,18 @@ class RegionSeriTest : public ::testing::Test void clearFileOnDisk() { TiFlashTestEnv::tryRemovePath(dir_path, /*recreate=*/true); } + static RegionPtr makeTmpRegion(RegionID region_id = 1001, TableID table_id = 1) + { + return makeRegion(createRegionMeta(region_id, table_id)); + } + const std::string dir_path; }; TEST_F(RegionSeriTest, peer) try { - auto peer = createPeer(100, true); + auto peer = RegionBench::createPeer(100, true); const auto path = dir_path + "/peer.test"; WriteBufferFromFile write_buf(path, DBMS_DEFAULT_BUFFER_SIZE, O_WRONLY | O_CREAT); auto size = writeBinary2(peer, write_buf); @@ -163,7 +178,13 @@ CATCH TEST_F(RegionSeriTest, RegionInfo) try { - auto region_info = createRegionInfo(233, "", ""); + { + // Test create RegionMeta. 
+ RegionMeta meta(createPeer(2, true), RegionBench::createMetaRegion(666, 0, 0, 1000), initialApplyState()); + ASSERT_EQ(meta.peerId(), 2); + } + + auto region_info = RegionBench::createMetaRegion(233, 66, 0, 200); const auto path = dir_path + "/region_info.test"; WriteBufferFromFile write_buf(path, DBMS_DEFAULT_BUFFER_SIZE, O_WRONLY | O_CREAT); auto size = writeBinary2(region_info, write_buf); @@ -273,14 +294,13 @@ try apply_state.mutable_truncated_state()->set_index(6672); apply_state.mutable_truncated_state()->set_term(6673); - *region_state.mutable_region() - = createRegionInfo(1001, RecordKVFormat::genKey(table_id, 0), RecordKVFormat::genKey(table_id, 300)); + *region_state.mutable_region() = RegionBench::createMetaRegion(1001, table_id, 0, 300); region_state.mutable_merge_state()->set_commit(888); region_state.mutable_merge_state()->set_min_index(777); *region_state.mutable_merge_state()->mutable_target() - = createRegionInfo(1111, RecordKVFormat::genKey(table_id, 300), RecordKVFormat::genKey(table_id, 400)); + = RegionBench::createMetaRegion(1111, table_id, 300, 400); } - region = makeRegion(RegionMeta(createPeer(31, true), apply_state, 5, region_state)); + region = makeRegion(RegionMeta(RegionBench::createPeer(31, true), apply_state, 5, region_state)); } TiKVKey key = RecordKVFormat::genKey(table_id, 323, 9983); diff --git a/dbms/src/Storages/KVStore/tests/gtest_spill.cpp b/dbms/src/Storages/KVStore/tests/gtest_spill.cpp index 19ebfb7d8a7..9449d82748d 100644 --- a/dbms/src/Storages/KVStore/tests/gtest_spill.cpp +++ b/dbms/src/Storages/KVStore/tests/gtest_spill.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include diff --git a/dbms/src/Storages/KVStore/tests/gtest_sync_schema.cpp b/dbms/src/Storages/KVStore/tests/gtest_sync_schema.cpp index 7ef8725ccf6..e1f9b10d183 100644 --- a/dbms/src/Storages/KVStore/tests/gtest_sync_schema.cpp +++ b/dbms/src/Storages/KVStore/tests/gtest_sync_schema.cpp @@ -14,6 +14,7 @@ #include #include 
+#include #include #include #include @@ -21,7 +22,6 @@ #include #include #include -#include #include #include #include diff --git a/dbms/src/Storages/KVStore/tests/gtest_sync_status.cpp b/dbms/src/Storages/KVStore/tests/gtest_sync_status.cpp index 62f11abc2f8..ce6b6398817 100644 --- a/dbms/src/Storages/KVStore/tests/gtest_sync_status.cpp +++ b/dbms/src/Storages/KVStore/tests/gtest_sync_status.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -29,7 +30,6 @@ #include #include #include -#include #include #include #include @@ -183,8 +183,7 @@ void createRegions(size_t region_num, TableID table_id) auto & tmt = TiFlashTestEnv::getContext()->getTMTContext(); for (size_t i = 0; i < region_num; i++) { - auto region - = makeRegion(i, RecordKVFormat::genKey(table_id, i), RecordKVFormat::genKey(table_id, i + region_num + 10)); + auto region = RegionBench::makeRegionForTable(i, table_id, i, i + region_num + 10); tmt.getRegionTable().shrinkRegionRange(*region); } } diff --git a/dbms/src/Storages/KVStore/tests/gtest_tikv_keyvalue.cpp b/dbms/src/Storages/KVStore/tests/gtest_tikv_keyvalue.cpp index 17ad8610372..37dceaefeca 100644 --- a/dbms/src/Storages/KVStore/tests/gtest_tikv_keyvalue.cpp +++ b/dbms/src/Storages/KVStore/tests/gtest_tikv_keyvalue.cpp @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include #include #include #include #include #include #include -#include #include #include diff --git a/dbms/src/Storages/KVStore/tests/kvstore_helper.h b/dbms/src/Storages/KVStore/tests/kvstore_helper.h index af7e3102ec1..c71c42fdae3 100644 --- a/dbms/src/Storages/KVStore/tests/kvstore_helper.h +++ b/dbms/src/Storages/KVStore/tests/kvstore_helper.h @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -32,7 +33,6 @@ #include #include #include -#include #include #include #include @@ -61,12 +61,6 @@ extern const char pause_passive_flush_before_persist_region[]; extern const char force_set_parallel_prehandle_threshold[]; } // namespace FailPoints -namespace RegionBench -{ -extern void setupPutRequest(raft_cmdpb::Request *, const std::string &, const TiKVKey &, const TiKVValue &); -extern void setupDelRequest(raft_cmdpb::Request *, const std::string &, const TiKVKey &); -} // namespace RegionBench - extern void CheckRegionForMergeCmd(const raft_cmdpb::AdminResponse & response, const RegionState & region_state); extern void ChangeRegionStateRange( RegionState & region_state, diff --git a/dbms/src/Storages/Page/V3/PageDirectory.cpp b/dbms/src/Storages/Page/V3/PageDirectory.cpp index 113c1b99ddb..a98d0686261 100644 --- a/dbms/src/Storages/Page/V3/PageDirectory.cpp +++ b/dbms/src/Storages/Page/V3/PageDirectory.cpp @@ -1644,15 +1644,17 @@ std::unordered_set PageDirectory::apply(PageEntriesEdit && edit, Stopwatch watch; std::unique_lock apply_lock(apply_mutex); - GET_METRIC(tiflash_storage_page_write_duration_seconds, type_latch).Observe(watch.elapsedSeconds()); watch.restart(); - writers.push_back(&w); + writers.push_back(&w); // push to write pipeline queue SYNC_FOR("after_PageDirectory::enter_write_group"); + // wait until becoming the write group owner or finished by the write group owner w.cv.wait(apply_lock, [&] { return w.done || &w == writers.front(); }); GET_METRIC(tiflash_storage_page_write_duration_seconds, 
type_wait_in_group).Observe(watch.elapsedSeconds()); watch.restart(); + + // finished by the write group owner if (w.done) { if (unlikely(!w.success)) @@ -1670,8 +1672,11 @@ std::unordered_set PageDirectory::apply(PageEntriesEdit && edit, // group owner, others just return an empty set. return {}; } + + /// This thread now is the write group owner, build the group. It will merge the + /// edits from `writers` to the owner's edit. auto * last_writer = buildWriteGroup(&w, apply_lock); - apply_lock.unlock(); + apply_lock.unlock(); // release the lock so that coming write could enter the pipeline queue SYNC_FOR("before_PageDirectory::leader_apply"); // `true` means the write process has completed without exception @@ -1680,6 +1685,7 @@ std::unordered_set PageDirectory::apply(PageEntriesEdit && edit, SCOPE_EXIT({ apply_lock.lock(); + // The write group owner exits, pop all finished `write` in the `writers` while (true) { auto * ready = writers.front(); @@ -1697,16 +1703,17 @@ std::unordered_set PageDirectory::apply(PageEntriesEdit && edit, if (ready == last_writer) break; } + // Try to wakeup next write as the owner if (!writers.empty()) { writers.front()->cv.notify_one(); } }); + /// Persist the write group owner's changes to WAL UInt64 max_sequence = sequence.load(); const auto edit_size = edit.size(); - // stage 1, persisted the changes to WAL. // In order to handle {put X, ref Y->X, del X} inside one WriteBatch (and // in later batch pipeline), we increase the sequence for each record. 
for (auto & r : edit.getMutRecords()) @@ -1718,6 +1725,8 @@ std::unordered_set PageDirectory::apply(PageEntriesEdit && edit, wal->apply(Trait::Serializer::serializeTo(edit), write_limiter); GET_METRIC(tiflash_storage_page_write_duration_seconds, type_wal).Observe(watch.elapsedSeconds()); watch.restart(); + + /// Commit the changes in memory's MVCC table SCOPE_EXIT({ // GET_METRIC(tiflash_storage_page_write_duration_seconds, type_commit).Observe(watch.elapsedSeconds()); }); @@ -1727,7 +1736,7 @@ std::unordered_set PageDirectory::apply(PageEntriesEdit && edit, { std::unique_lock table_lock(table_rw_mutex); - // stage 2, create entry version list for page_id. + // create entry version list for page_id. for (const auto & r : edit.getRecords()) { // Protected in write_lock diff --git a/dbms/src/Storages/Page/V3/Universal/UniversalPageIdFormatImpl.h b/dbms/src/Storages/Page/V3/Universal/UniversalPageIdFormatImpl.h index fc3deea49c3..3a6b434e874 100644 --- a/dbms/src/Storages/Page/V3/Universal/UniversalPageIdFormatImpl.h +++ b/dbms/src/Storages/Page/V3/Universal/UniversalPageIdFormatImpl.h @@ -51,14 +51,14 @@ namespace DB // // Storage key // Meta -// Prefix = [optional KeyspaceID] + "tm" + NamespaceID +// Prefix = [optional KeyspaceID] + "tm"(0x746D) + NamespaceID // Log -// Prefix = [optional KeyspaceID] + "tl" + NamespaceID +// Prefix = [optional KeyspaceID] + "tl"(0x746C) + NamespaceID // Data -// Prefix = [optional KeyspaceID] + "td" + NamespaceID +// Prefix = [optional KeyspaceID] + "td"(0x7464) + NamespaceID // // KeyspaceID format is the same as in https://github.com/tikv/rfcs/blob/master/text/0069-api-v2.md -// 'x'(TXN_MODE_PREFIX) + keyspace_id(3 bytes, big endian) +// 'x'(TXN_MODE_PREFIX, which is 0x78) + keyspace_id(3 bytes, big endian) // Note 'x' will be a reserved keyword, and should not be used in other prefix. // If the first byte of a UniversalPageId is 'x', the next 3 bytes will be considered a KeyspaceID. // If not, NullspaceID will be returned. 
diff --git a/dbms/src/Storages/Page/tools/PageCtl/PageStorageCtlV3.cpp b/dbms/src/Storages/Page/tools/PageCtl/PageStorageCtlV3.cpp index 174dcd504b0..df8fcfba607 100644 --- a/dbms/src/Storages/Page/tools/PageCtl/PageStorageCtlV3.cpp +++ b/dbms/src/Storages/Page/tools/PageCtl/PageStorageCtlV3.cpp @@ -32,9 +32,9 @@ #include #include #include +#include #include -#include #include #include #include diff --git a/dbms/src/Storages/Page/workload/PSStressEnv.cpp b/dbms/src/Storages/Page/workload/PSStressEnv.cpp index 11df93d5029..4320f85efec 100644 --- a/dbms/src/Storages/Page/workload/PSStressEnv.cpp +++ b/dbms/src/Storages/Page/workload/PSStressEnv.cpp @@ -25,10 +25,9 @@ #include #include #include +#include #include -#include - namespace DB::PS::tests { LoggerPtr StressEnv::buildLogger(bool enable_color) diff --git a/dbms/src/Storages/S3/FileCache.cpp b/dbms/src/Storages/S3/FileCache.cpp index c47797d86a5..efa4e8f1a0c 100644 --- a/dbms/src/Storages/S3/FileCache.cpp +++ b/dbms/src/Storages/S3/FileCache.cpp @@ -642,6 +642,10 @@ FileType FileCache::getFileType(const String & fname) { return FileType::VectorIndex; } + else if (ext == ".inverted") + { + return FileType::InvertedIndex; + } else if (ext == ".meta") { // Example: v1.meta diff --git a/dbms/src/Storages/S3/FileCache.h b/dbms/src/Storages/S3/FileCache.h index 971cdc3f10d..38007b9ac6d 100644 --- a/dbms/src/Storages/S3/FileCache.h +++ b/dbms/src/Storages/S3/FileCache.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -53,6 +54,7 @@ class FileSegment // which must be downloaded to the local disk. 
// So the priority of caching is relatively high VectorIndex, + InvertedIndex, Merged, Index, Mark, // .mkr, .null.mrk @@ -228,7 +230,12 @@ class FileCache : nullptr; } - static void shutdown() { global_file_cache_instance = nullptr; } + static void shutdown() + { + // wait for all tasks done + S3FileCachePool::shutdown(); + global_file_cache_instance = nullptr; + } FileCache(PathCapacityMetricsPtr capacity_metrics_, const StorageRemoteCacheConfig & config_); @@ -299,6 +306,7 @@ class FileCache 0, // Unknow type, currently never cache it. 8 * 1024, // Estimated size of meta. 12 * 1024 * 1024, // Estimated size of vector index + 128 * 1024, // Estimated size of inverted index. 1 * 1024 * 1024, // Estimated size of merged. 8 * 1024, // Estimated size of index. 8 * 1024, // Estimated size of mark. diff --git a/dbms/src/Storages/S3/tests/gtest_filecache.cpp b/dbms/src/Storages/S3/tests/gtest_filecache.cpp index 5dca3b79dce..1eca339bdbc 100644 --- a/dbms/src/Storages/S3/tests/gtest_filecache.cpp +++ b/dbms/src/Storages/S3/tests/gtest_filecache.cpp @@ -448,191 +448,23 @@ try auto unknow_fname1 = fmt::format("{}/123456.lock", s3_fname); ASSERT_EQ(FileCache::getFileType(unknow_fname1), FileType::Unknow); + for (UInt64 level = 0; level <= magic_enum::enum_count(); ++level) { - UInt64 level = 0; auto cache_dir = fmt::format("{}/filetype{}", tmp_dir, level); StorageRemoteCacheConfig cache_config{.dir = cache_dir, .capacity = cache_capacity, .dtfile_level = level}; FileCache file_cache(capacity_metrics, cache_config); ASSERT_FALSE(file_cache.canCache(FileType::Unknow)); - ASSERT_FALSE(file_cache.canCache(FileType::Meta)); - ASSERT_FALSE(file_cache.canCache(FileType::Merged)); - ASSERT_FALSE(file_cache.canCache(FileType::Index)); - ASSERT_FALSE(file_cache.canCache(FileType::Mark)); - ASSERT_FALSE(file_cache.canCache(FileType::NullMap)); - ASSERT_FALSE(file_cache.canCache(FileType::DeleteMarkColData)); - ASSERT_FALSE(file_cache.canCache(FileType::VersionColData)); - 
ASSERT_FALSE(file_cache.canCache(FileType::HandleColData)); - ASSERT_FALSE(file_cache.canCache(FileType::ColData)); - } - { - UInt64 level = 1; - auto cache_dir = fmt::format("{}/filetype{}", tmp_dir, level); - StorageRemoteCacheConfig cache_config{.dir = cache_dir, .capacity = cache_capacity, .dtfile_level = level}; - FileCache file_cache(capacity_metrics, cache_config); - ASSERT_FALSE(file_cache.canCache(FileType::Unknow)); - ASSERT_TRUE(file_cache.canCache(FileType::Meta)); - ASSERT_FALSE(file_cache.canCache(FileType::VectorIndex)); - ASSERT_FALSE(file_cache.canCache(FileType::Merged)); - ASSERT_FALSE(file_cache.canCache(FileType::Index)); - ASSERT_FALSE(file_cache.canCache(FileType::Mark)); - ASSERT_FALSE(file_cache.canCache(FileType::NullMap)); - ASSERT_FALSE(file_cache.canCache(FileType::DeleteMarkColData)); - ASSERT_FALSE(file_cache.canCache(FileType::VersionColData)); - ASSERT_FALSE(file_cache.canCache(FileType::HandleColData)); - ASSERT_FALSE(file_cache.canCache(FileType::ColData)); - } - { - UInt64 level = 2; - auto cache_dir = fmt::format("{}/filetype{}", tmp_dir, level); - StorageRemoteCacheConfig cache_config{.dir = cache_dir, .capacity = cache_capacity, .dtfile_level = level}; - FileCache file_cache(capacity_metrics, cache_config); - ASSERT_FALSE(file_cache.canCache(FileType::Unknow)); - ASSERT_TRUE(file_cache.canCache(FileType::Meta)); - ASSERT_TRUE(file_cache.canCache(FileType::VectorIndex)); - ASSERT_FALSE(file_cache.canCache(FileType::Merged)); - ASSERT_FALSE(file_cache.canCache(FileType::Index)); - ASSERT_FALSE(file_cache.canCache(FileType::Mark)); - ASSERT_FALSE(file_cache.canCache(FileType::NullMap)); - ASSERT_FALSE(file_cache.canCache(FileType::DeleteMarkColData)); - ASSERT_FALSE(file_cache.canCache(FileType::VersionColData)); - ASSERT_FALSE(file_cache.canCache(FileType::HandleColData)); - ASSERT_FALSE(file_cache.canCache(FileType::ColData)); - } - { - UInt64 level = 3; - auto cache_dir = fmt::format("{}/filetype{}", tmp_dir, level); - 
StorageRemoteCacheConfig cache_config{.dir = cache_dir, .capacity = cache_capacity, .dtfile_level = level}; - FileCache file_cache(capacity_metrics, cache_config); - ASSERT_FALSE(file_cache.canCache(FileType::Unknow)); - ASSERT_TRUE(file_cache.canCache(FileType::Meta)); - ASSERT_TRUE(file_cache.canCache(FileType::VectorIndex)); - ASSERT_TRUE(file_cache.canCache(FileType::Merged)); - ASSERT_FALSE(file_cache.canCache(FileType::Index)); - ASSERT_FALSE(file_cache.canCache(FileType::Mark)); - ASSERT_FALSE(file_cache.canCache(FileType::NullMap)); - ASSERT_FALSE(file_cache.canCache(FileType::DeleteMarkColData)); - ASSERT_FALSE(file_cache.canCache(FileType::VersionColData)); - ASSERT_FALSE(file_cache.canCache(FileType::HandleColData)); - ASSERT_FALSE(file_cache.canCache(FileType::ColData)); - } - { - UInt64 level = 4; - auto cache_dir = fmt::format("{}/filetype{}", tmp_dir, level); - StorageRemoteCacheConfig cache_config{.dir = cache_dir, .capacity = cache_capacity, .dtfile_level = level}; - FileCache file_cache(capacity_metrics, cache_config); - ASSERT_FALSE(file_cache.canCache(FileType::Unknow)); - ASSERT_TRUE(file_cache.canCache(FileType::Meta)); - ASSERT_TRUE(file_cache.canCache(FileType::VectorIndex)); - ASSERT_TRUE(file_cache.canCache(FileType::Merged)); - ASSERT_TRUE(file_cache.canCache(FileType::Index)); - ASSERT_FALSE(file_cache.canCache(FileType::Mark)); - ASSERT_FALSE(file_cache.canCache(FileType::NullMap)); - ASSERT_FALSE(file_cache.canCache(FileType::DeleteMarkColData)); - ASSERT_FALSE(file_cache.canCache(FileType::VersionColData)); - ASSERT_FALSE(file_cache.canCache(FileType::HandleColData)); - ASSERT_FALSE(file_cache.canCache(FileType::ColData)); - } - { - UInt64 level = 5; - auto cache_dir = fmt::format("{}/filetype{}", tmp_dir, level); - StorageRemoteCacheConfig cache_config{.dir = cache_dir, .capacity = cache_capacity, .dtfile_level = level}; - FileCache file_cache(capacity_metrics, cache_config); - ASSERT_FALSE(file_cache.canCache(FileType::Unknow)); - 
ASSERT_TRUE(file_cache.canCache(FileType::Meta)); - ASSERT_TRUE(file_cache.canCache(FileType::VectorIndex)); - ASSERT_TRUE(file_cache.canCache(FileType::Merged)); - ASSERT_TRUE(file_cache.canCache(FileType::Index)); - ASSERT_TRUE(file_cache.canCache(FileType::Mark)); - ASSERT_FALSE(file_cache.canCache(FileType::NullMap)); - ASSERT_FALSE(file_cache.canCache(FileType::DeleteMarkColData)); - ASSERT_FALSE(file_cache.canCache(FileType::VersionColData)); - ASSERT_FALSE(file_cache.canCache(FileType::HandleColData)); - ASSERT_FALSE(file_cache.canCache(FileType::ColData)); - } - { - UInt64 level = 6; - auto cache_dir = fmt::format("{}/filetype{}", tmp_dir, level); - StorageRemoteCacheConfig cache_config{.dir = cache_dir, .capacity = cache_capacity, .dtfile_level = level}; - FileCache file_cache(capacity_metrics, cache_config); - ASSERT_FALSE(file_cache.canCache(FileType::Unknow)); - ASSERT_TRUE(file_cache.canCache(FileType::Meta)); - ASSERT_TRUE(file_cache.canCache(FileType::VectorIndex)); - ASSERT_TRUE(file_cache.canCache(FileType::Merged)); - ASSERT_TRUE(file_cache.canCache(FileType::Index)); - ASSERT_TRUE(file_cache.canCache(FileType::Mark)); - ASSERT_TRUE(file_cache.canCache(FileType::NullMap)); - ASSERT_FALSE(file_cache.canCache(FileType::DeleteMarkColData)); - ASSERT_FALSE(file_cache.canCache(FileType::VersionColData)); - ASSERT_FALSE(file_cache.canCache(FileType::HandleColData)); - ASSERT_FALSE(file_cache.canCache(FileType::ColData)); - } - { - UInt64 level = 7; - auto cache_dir = fmt::format("{}/filetype{}", tmp_dir, level); - StorageRemoteCacheConfig cache_config{.dir = cache_dir, .capacity = cache_capacity, .dtfile_level = level}; - FileCache file_cache(capacity_metrics, cache_config); - ASSERT_FALSE(file_cache.canCache(FileType::Unknow)); - ASSERT_TRUE(file_cache.canCache(FileType::Meta)); - ASSERT_TRUE(file_cache.canCache(FileType::VectorIndex)); - ASSERT_TRUE(file_cache.canCache(FileType::Merged)); - ASSERT_TRUE(file_cache.canCache(FileType::Index)); - 
ASSERT_TRUE(file_cache.canCache(FileType::Mark)); - ASSERT_TRUE(file_cache.canCache(FileType::NullMap)); - ASSERT_TRUE(file_cache.canCache(FileType::DeleteMarkColData)); - ASSERT_FALSE(file_cache.canCache(FileType::VersionColData)); - ASSERT_FALSE(file_cache.canCache(FileType::HandleColData)); - ASSERT_FALSE(file_cache.canCache(FileType::ColData)); - } - { - UInt64 level = 8; - auto cache_dir = fmt::format("{}/filetype{}", tmp_dir, level); - StorageRemoteCacheConfig cache_config{.dir = cache_dir, .capacity = cache_capacity, .dtfile_level = level}; - FileCache file_cache(capacity_metrics, cache_config); - ASSERT_FALSE(file_cache.canCache(FileType::Unknow)); - ASSERT_TRUE(file_cache.canCache(FileType::Meta)); - ASSERT_TRUE(file_cache.canCache(FileType::VectorIndex)); - ASSERT_TRUE(file_cache.canCache(FileType::Merged)); - ASSERT_TRUE(file_cache.canCache(FileType::Index)); - ASSERT_TRUE(file_cache.canCache(FileType::Mark)); - ASSERT_TRUE(file_cache.canCache(FileType::NullMap)); - ASSERT_TRUE(file_cache.canCache(FileType::DeleteMarkColData)); - ASSERT_TRUE(file_cache.canCache(FileType::VersionColData)); - ASSERT_FALSE(file_cache.canCache(FileType::HandleColData)); - ASSERT_FALSE(file_cache.canCache(FileType::ColData)); - } - { - UInt64 level = 9; - auto cache_dir = fmt::format("{}/filetype{}", tmp_dir, level); - StorageRemoteCacheConfig cache_config{.dir = cache_dir, .capacity = cache_capacity, .dtfile_level = level}; - FileCache file_cache(capacity_metrics, cache_config); - ASSERT_FALSE(file_cache.canCache(FileType::Unknow)); - ASSERT_TRUE(file_cache.canCache(FileType::Meta)); - ASSERT_TRUE(file_cache.canCache(FileType::VectorIndex)); - ASSERT_TRUE(file_cache.canCache(FileType::Merged)); - ASSERT_TRUE(file_cache.canCache(FileType::Index)); - ASSERT_TRUE(file_cache.canCache(FileType::Mark)); - ASSERT_TRUE(file_cache.canCache(FileType::NullMap)); - ASSERT_TRUE(file_cache.canCache(FileType::DeleteMarkColData)); - 
ASSERT_TRUE(file_cache.canCache(FileType::VersionColData)); - ASSERT_TRUE(file_cache.canCache(FileType::HandleColData)); - ASSERT_FALSE(file_cache.canCache(FileType::ColData)); - } - { - UInt64 level = 10; - auto cache_dir = fmt::format("{}/filetype{}", tmp_dir, level); - StorageRemoteCacheConfig cache_config{.dir = cache_dir, .capacity = cache_capacity, .dtfile_level = level}; - FileCache file_cache(capacity_metrics, cache_config); - ASSERT_FALSE(file_cache.canCache(FileType::Unknow)); - ASSERT_TRUE(file_cache.canCache(FileType::Meta)); - ASSERT_TRUE(file_cache.canCache(FileType::VectorIndex)); - ASSERT_TRUE(file_cache.canCache(FileType::Merged)); - ASSERT_TRUE(file_cache.canCache(FileType::Index)); - ASSERT_TRUE(file_cache.canCache(FileType::Mark)); - ASSERT_TRUE(file_cache.canCache(FileType::NullMap)); - ASSERT_TRUE(file_cache.canCache(FileType::DeleteMarkColData)); - ASSERT_TRUE(file_cache.canCache(FileType::VersionColData)); - ASSERT_TRUE(file_cache.canCache(FileType::HandleColData)); - ASSERT_TRUE(file_cache.canCache(FileType::ColData)); + ASSERT_EQ(file_cache.canCache(FileType::Meta), level >= 1); + ASSERT_EQ(file_cache.canCache(FileType::VectorIndex), level >= 2); + ASSERT_EQ(file_cache.canCache(FileType::InvertedIndex), level >= 3); + ASSERT_EQ(file_cache.canCache(FileType::Merged), level >= 4); + ASSERT_EQ(file_cache.canCache(FileType::Index), level >= 5); + ASSERT_EQ(file_cache.canCache(FileType::Mark), level >= 6); + ASSERT_EQ(file_cache.canCache(FileType::NullMap), level >= 7); + ASSERT_EQ(file_cache.canCache(FileType::DeleteMarkColData), level >= 8); + ASSERT_EQ(file_cache.canCache(FileType::VersionColData), level >= 9); + ASSERT_EQ(file_cache.canCache(FileType::HandleColData), level >= 10); + ASSERT_EQ(file_cache.canCache(FileType::ColData), level >= 11); } } CATCH diff --git a/dbms/src/Storages/SelectQueryInfo.h b/dbms/src/Storages/SelectQueryInfo.h index 7f8f2987ece..2b7dc2b898e 100644 --- a/dbms/src/Storages/SelectQueryInfo.h +++ 
b/dbms/src/Storages/SelectQueryInfo.h @@ -42,7 +42,7 @@ struct SelectQueryInfo { ASTPtr query; - /// Prepared sets are used for indices by storage engine. + /// Prepared sets are used for handling queries with `IN` section. /// Example: x IN (1, 2, 3) PreparedSets sets; @@ -61,7 +61,7 @@ struct SelectQueryInfo SelectQueryInfo(const SelectQueryInfo & rhs); SelectQueryInfo(SelectQueryInfo && rhs) noexcept; - bool fromAST() const { return dag_query == nullptr; }; + bool fromAST() const { return dag_query == nullptr; } }; } // namespace DB diff --git a/dbms/src/Storages/System/StorageSystemTables.cpp b/dbms/src/Storages/System/StorageSystemTables.cpp index 8d0e40f80ce..1fc7b7b8c61 100644 --- a/dbms/src/Storages/System/StorageSystemTables.cpp +++ b/dbms/src/Storages/System/StorageSystemTables.cpp @@ -282,31 +282,6 @@ BlockInputStreams StorageSystemTables::read( } } - if (context.hasSessionContext()) - { - Tables external_tables = context.getSessionContext().getExternalTables(); - - for (const auto & table : external_tables) - { - size_t j = 0; - res_columns[j++]->insertDefault(); - res_columns[j++]->insert(table.first); - res_columns[j++]->insert(table.second->getName()); - res_columns[j++]->insert(static_cast(1)); - res_columns[j++]->insertDefault(); - res_columns[j++]->insertDefault(); - - if (has_metadata_modification_time) - res_columns[j++]->insertDefault(); - - if (has_create_table_query) - res_columns[j++]->insertDefault(); - - if (has_engine_full) - res_columns[j++]->insert(table.second->getName()); - } - } - res_block.setColumns(std::move(res_columns)); return {std::make_shared(res_block)}; } diff --git a/dbms/src/TiDB/Schema/InvertedIndex.h b/dbms/src/TiDB/Schema/InvertedIndex.h new file mode 100644 index 00000000000..222464d4b94 --- /dev/null +++ b/dbms/src/TiDB/Schema/InvertedIndex.h @@ -0,0 +1,59 @@ +// Copyright 2025 PingCAP, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace TiDB +{ + +// Constructed from table definition. +struct InvertedIndexDefinition +{ + bool is_signed; + UInt8 type_size; +}; + +// As this is constructed from TiDB's table definition, we should not +// ever try to modify it anyway. +using InvertedIndexDefinitionPtr = std::shared_ptr; +} // namespace TiDB + +template <> +struct fmt::formatter +{ + static constexpr auto parse(format_parse_context & ctx) { return ctx.begin(); } + + template + auto format(const TiDB::InvertedIndexDefinition & index, FormatContext & ctx) const -> decltype(ctx.out()) + { + return fmt::format_to(ctx.out(), "{}{}", index.is_signed ? 
"" : "U", index.type_size * 8); + } +}; + +template <> +struct fmt::formatter +{ + static constexpr auto parse(format_parse_context & ctx) { return ctx.begin(); } + + template + auto format(const TiDB::InvertedIndexDefinitionPtr & index, FormatContext & ctx) const -> decltype(ctx.out()) + { + if (!index) + return fmt::format_to(ctx.out(), ""); + return fmt::format_to(ctx.out(), "{}", *index); + } +}; diff --git a/dbms/src/TiDB/Schema/TiDB.h b/dbms/src/TiDB/Schema/TiDB.h index 89c5ca7938a..8344b84b143 100644 --- a/dbms/src/TiDB/Schema/TiDB.h +++ b/dbms/src/TiDB/Schema/TiDB.h @@ -249,6 +249,7 @@ enum class ColumnarIndexKind { // Leave 0 intentionally for InvalidValues Vector = 1, + Inverted = 2, }; struct IndexInfo diff --git a/libs/libcommon/CMakeLists.txt b/libs/libcommon/CMakeLists.txt index 90f1b20e700..3c2d9ddfb4a 100644 --- a/libs/libcommon/CMakeLists.txt +++ b/libs/libcommon/CMakeLists.txt @@ -66,7 +66,6 @@ add_library (common ${SPLIT_SHARED} include/common/mremap.h include/common/likely.h include/common/logger_useful.h - include/common/MultiVersion.h include/common/strong_typedef.h include/common/JSON.h include/common/simd.h diff --git a/libs/libcommon/include/common/MultiVersion.h b/libs/libcommon/include/common/MultiVersion.h deleted file mode 100644 index 05806d229be..00000000000 --- a/libs/libcommon/include/common/MultiVersion.h +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include -#include - - -/** Allow to store and read-only usage of an object in several threads, - * and to atomically replace an object in another thread. - * The replacement is atomic and reading threads can work with different versions of an object. - * - * Usage: - * MultiVersion x; - * - on data update: - * x.set(new value); - * - on read-only usage: - * { - * MultiVersion::Version current_version = x.get(); - * // use *current_version - * } // now we finish own current version; if the version is outdated and no one else is using it - it will be destroyed. - * - * All methods are thread-safe. - */ -template -class MultiVersion -{ -public: - /// Version of object for usage. shared_ptr manage lifetime of version. - using Version = std::shared_ptr; - - /// Default initialization - by nullptr. - MultiVersion() = default; - - MultiVersion(std::unique_ptr && value) { set(std::move(value)); } - - /// Obtain current version for read-only usage. Returns shared_ptr, that manages lifetime of version. - Version get() const - { - /// NOTE: is it possible to lock-free replace of shared_ptr? - std::lock_guard lock(mutex); - return current_version; - } - - /// Update an object with new version. - void set(std::unique_ptr && value) - { - std::lock_guard lock(mutex); - current_version = std::move(value); - } - -private: - Version current_version; - mutable std::mutex mutex; -}; diff --git a/libs/libcommon/include/common/iostream_debug_helpers.h b/libs/libcommon/include/common/iostream_debug_helpers.h deleted file mode 100644 index b92ac8ab351..00000000000 --- a/libs/libcommon/include/common/iostream_debug_helpers.h +++ /dev/null @@ -1,210 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include - -#include -#include -#include - -/** Usage: - * - * DUMP(variable...) - */ - - -template -Out & dumpValue(Out &, T &&); - - -/// Catch-all case. -template -std::enable_if_t & dumpImpl(Out & out, T &&) -{ - return out << "{...}"; -} - -/// An object, that could be output with operator <<. -template -std::enable_if_t & dumpImpl( - Out & out, - T && x, - std::decay_t() << std::declval())> * = nullptr) -{ - return out << x; -} - -/// A pointer-like object. -template -std::enable_if_t< - priority == 1 - /// Protect from the case when operator * do effectively nothing (function pointer). - && !std::is_same_v, std::decay_t())>>, - Out> & -dumpImpl(Out & out, T && x, std::decay_t())> * = nullptr) -{ - if (!x) - return out << "nullptr"; - return dumpValue(out, *x); -} - -/// Container. -template -std::enable_if_t & dumpImpl( - Out & out, - T && x, - std::decay_t()))> * = nullptr) -{ - bool first = true; - out << "{"; - for (const auto & elem : x) - { - if (first) - first = false; - else - out << ", "; - dumpValue(out, elem); - } - return out << "}"; -} - - -/// string and const char * - output not as container or pointer. - -template -std::enable_if_t, std::string> || std::is_same_v, const char *>), Out> & dumpImpl( - Out & out, - T && x) -{ - return out << std::quoted(x); -} - -/// UInt8 - output as number, not char. 
- -template -std::enable_if_t, unsigned char>, Out> & dumpImpl(Out & out, T && x) -{ - return out << int(x); -} - - -/// Tuple, pair -template -Out & dumpTupleImpl(Out & out, T && x) -{ - if constexpr (N == 0) - out << "{"; - else - out << ", "; - - dumpValue(out, std::get(x)); - - if constexpr (N + 1 == std::tuple_size_v>) - out << "}"; - else - dumpTupleImpl(out, x); - - return out; -} - -template -std::enable_if_t & dumpImpl( - Out & out, - T && x, - std::decay_t(std::declval()))> * = nullptr) -{ - return dumpTupleImpl<0>(out, x); -} - - -template -Out & dumpDispatchPriorities( - Out & out, - T && x, - std::decay_t(std::declval(), std::declval()))> *) -{ - return dumpImpl(out, x); -} - -struct LowPriority -{ - LowPriority(void *) {} -}; - -template -Out & dumpDispatchPriorities(Out & out, T && x, LowPriority) -{ - return dumpDispatchPriorities(out, x, nullptr); -} - - -template -Out & dumpValue(Out & out, T && x) -{ - return dumpDispatchPriorities<5>(out, x, nullptr); -} - - -template -Out & dump(Out & out, const char * name, T && x) -{ - out << demangle(typeid(x).name()) << " " << name << " = "; - return dumpValue(out, x); -} - - -#define DUMPVAR(VAR) \ - dump(std::cerr, #VAR, (VAR)); \ - std::cerr << "; "; -#define DUMPHEAD std::cerr << __FILE__ << ':' << __LINE__ << " "; -#define DUMPTAIL std::cerr << '\n'; - -#define DUMP1(V1) \ - do \ - { \ - DUMPHEAD DUMPVAR(V1) DUMPTAIL \ - } while (0); -#define DUMP2(V1, V2) \ - do \ - { \ - DUMPHEAD DUMPVAR(V1) DUMPVAR(V2) DUMPTAIL \ - } while (0); -#define DUMP3(V1, V2, V3) \ - do \ - { \ - DUMPHEAD DUMPVAR(V1) DUMPVAR(V2) DUMPVAR(V3) DUMPTAIL \ - } while (0); -#define DUMP4(V1, V2, V3, V4) \ - do \ - { \ - DUMPHEAD DUMPVAR(V1) DUMPVAR(V2) DUMPVAR(V3) DUMPVAR(V4) DUMPTAIL \ - } while (0); -#define DUMP5(V1, V2, V3, V4, V5) \ - do \ - { \ - DUMPHEAD DUMPVAR(V1) DUMPVAR(V2) DUMPVAR(V3) DUMPVAR(V4) DUMPVAR(V5) DUMPTAIL \ - } while (0); - -/// 
https://groups.google.com/forum/#!searchin/kona-dev/variadic$20macro%7Csort:date/kona-dev/XMA-lDOqtlI/GCzdfZsD41sJ - -#define VA_NUM_ARGS_IMPL(x1, x2, x3, x4, x5, N, ...) N -#define VA_NUM_ARGS(...) VA_NUM_ARGS_IMPL(__VA_ARGS__, 5, 4, 3, 2, 1) - -#define MAKE_VAR_MACRO_IMPL_CONCAT(PREFIX, NUM_ARGS) PREFIX##NUM_ARGS -#define MAKE_VAR_MACRO_IMPL(PREFIX, NUM_ARGS) MAKE_VAR_MACRO_IMPL_CONCAT(PREFIX, NUM_ARGS) -#define MAKE_VAR_MACRO(PREFIX, ...) MAKE_VAR_MACRO_IMPL(PREFIX, VA_NUM_ARGS(__VA_ARGS__)) - -#define DUMP(...) MAKE_VAR_MACRO(DUMP, __VA_ARGS__)(__VA_ARGS__) diff --git a/libs/libcommon/src/tests/CMakeLists.txt b/libs/libcommon/src/tests/CMakeLists.txt index 15d57a2ad09..c5a2746777f 100644 --- a/libs/libcommon/src/tests/CMakeLists.txt +++ b/libs/libcommon/src/tests/CMakeLists.txt @@ -14,23 +14,6 @@ include (${TiFlash_SOURCE_DIR}/cmake/add_check.cmake) -add_executable (date_lut_init date_lut_init.cpp) -add_executable (date_lut2 date_lut2.cpp) -add_executable (date_lut3 date_lut3.cpp) -add_executable (date_lut4 date_lut4.cpp) -add_executable (date_lut_default_timezone date_lut_default_timezone.cpp) -add_executable (multi_version multi_version.cpp) - -set(PLATFORM_LIBS ${CMAKE_DL_LIBS}) - -target_link_libraries (date_lut_init common ${PLATFORM_LIBS}) -target_link_libraries (date_lut2 common ${PLATFORM_LIBS}) -target_link_libraries (date_lut3 common ${PLATFORM_LIBS}) -target_link_libraries (date_lut4 common ${PLATFORM_LIBS}) -target_link_libraries (date_lut_default_timezone common ${PLATFORM_LIBS}) -target_link_libraries (multi_version common) -add_check(multi_version) - add_executable (gtests_libcommon gtest_json_test.cpp gtest_strong_typedef.cpp diff --git a/libs/libcommon/src/tests/date_lut2.cpp b/libs/libcommon/src/tests/date_lut2.cpp deleted file mode 100644 index 94327d2afab..00000000000 --- a/libs/libcommon/src/tests/date_lut2.cpp +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2023 PingCAP, Inc. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include - -#include -#include - - -static std::string toString(time_t Value) -{ - struct tm tm; - char buf[96]; - - localtime_r(&Value, &tm); - snprintf( - buf, - sizeof(buf), - "%04d-%02d-%02d %02d:%02d:%02d", - tm.tm_year + 1900, - tm.tm_mon + 1, - tm.tm_mday, - tm.tm_hour, - tm.tm_min, - tm.tm_sec); - - return buf; -} - -static time_t orderedIdentifierToDate(unsigned value) -{ - struct tm tm; - - memset(&tm, 0, sizeof(tm)); - - tm.tm_year = value / 10000 - 1900; - tm.tm_mon = (value % 10000) / 100 - 1; - tm.tm_mday = value % 100; - tm.tm_isdst = -1; - - return mktime(&tm); -} - - -void loop(time_t begin, time_t end, int step) -{ - const auto & date_lut = DateLUT::instance(); - - for (time_t t = begin; t < end; t += step) - std::cout << toString(t) << ", " << toString(date_lut.toTime(t)) << ", " << date_lut.toHour(t) << std::endl; -} - - -int main(int argc, char ** argv) -{ - loop(orderedIdentifierToDate(20101031), orderedIdentifierToDate(20101101), 15 * 60); - loop(orderedIdentifierToDate(20100328), orderedIdentifierToDate(20100330), 15 * 60); - loop(orderedIdentifierToDate(20141020), orderedIdentifierToDate(20141106), 15 * 60); - - return 0; -} diff --git a/libs/libcommon/src/tests/date_lut3.cpp b/libs/libcommon/src/tests/date_lut3.cpp deleted file mode 100644 index dd8698cf1da..00000000000 --- a/libs/libcommon/src/tests/date_lut3.cpp +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright 2023 PingCAP, Inc. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -#include -#include - - -static std::string toString(time_t Value) -{ - struct tm tm; - char buf[96]; - - localtime_r(&Value, &tm); - snprintf( - buf, - sizeof(buf), - "%04d-%02d-%02d %02d:%02d:%02d", - tm.tm_year + 1900, - tm.tm_mon + 1, - tm.tm_mday, - tm.tm_hour, - tm.tm_min, - tm.tm_sec); - - return buf; -} - -static time_t orderedIdentifierToDate(unsigned value) -{ - struct tm tm; - - memset(&tm, 0, sizeof(tm)); - - tm.tm_year = value / 10000 - 1900; - tm.tm_mon = (value % 10000) / 100 - 1; - tm.tm_mday = value % 100; - tm.tm_isdst = -1; - - return mktime(&tm); -} - - -void loop(time_t begin, time_t end, int step) -{ - const auto & date_lut = DateLUT::instance(); - - for (time_t t = begin; t < end; t += step) - { - time_t t2 = date_lut.makeDateTime( - date_lut.toYear(t), - date_lut.toMonth(t), - date_lut.toDayOfMonth(t), - date_lut.toHour(t), - date_lut.toMinute(t), - date_lut.toSecond(t)); - - std::string s1 = toString(t); - std::string s2 = toString(t2); - - std::cerr << s1 << ", " << s2 << std::endl; - - if (s1 != s2) - throw Poco::Exception("Test failed."); - } -} - - -int main(int argc, char ** argv) -{ - loop(orderedIdentifierToDate(20101031), orderedIdentifierToDate(20101101), 15 * 60); - loop(orderedIdentifierToDate(20100328), orderedIdentifierToDate(20100330), 15 * 60); - - return 0; -} diff --git a/libs/libcommon/src/tests/date_lut4.cpp 
b/libs/libcommon/src/tests/date_lut4.cpp deleted file mode 100644 index 2fd893a4f05..00000000000 --- a/libs/libcommon/src/tests/date_lut4.cpp +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include - -#include - - -int main(int argc, char ** argv) -{ - /** В DateLUT был глюк - для времён из дня 1970-01-01, возвращался номер часа больше 23. */ - static const time_t TIME = 66130; - - const auto & date_lut = DateLUT::instance(); - - std::cerr << date_lut.toHour(TIME) << std::endl; - std::cerr << date_lut.toDayNum(TIME) << std::endl; - - const auto * values = reinterpret_cast(&date_lut); - - std::cerr << values[0].date << ", " << time_t(values[1].date - values[0].date) << std::endl; - - return 0; -} diff --git a/libs/libcommon/src/tests/date_lut_default_timezone.cpp b/libs/libcommon/src/tests/date_lut_default_timezone.cpp deleted file mode 100644 index 89783e6f7f5..00000000000 --- a/libs/libcommon/src/tests/date_lut_default_timezone.cpp +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -#include - -int main(int argc, char ** argv) -{ - try - { - const auto & date_lut = DateLUT::instance(); - std::cout << "Detected default timezone: `" << date_lut.getTimeZone() << "'" << std::endl; - time_t now = time(NULL); - std::cout << "Current time: " << date_lut.timeToString(now) - << ", UTC: " << DateLUT::instance("UTC").timeToString(now) << std::endl; - } - catch (const Poco::Exception & e) - { - std::cerr << e.displayText() << std::endl; - return 1; - } - catch (std::exception & e) - { - std::cerr << "std::exception: " << e.what() << std::endl; - return 2; - } - catch (...) - { - std::cerr << "Some exception" << std::endl; - return 3; - } - return 0; -} diff --git a/libs/libcommon/src/tests/multi_version.cpp b/libs/libcommon/src/tests/multi_version.cpp deleted file mode 100644 index 003b96e2efb..00000000000 --- a/libs/libcommon/src/tests/multi_version.cpp +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include -#include -#include -#include - -#include -#include - - -using T = std::string; -using MV = MultiVersion; -using Results = std::vector; - - -void thread1(MV & x, T & result) -{ - MV::Version v = x.get(); - result = *v; -} - -void thread2(MV & x, const char * result) -{ - x.set(std::make_unique(result)); -} - - -int main(int argc, char ** argv) -{ - try - { - const char * s1 = "Hello!"; - const char * s2 = "Goodbye!"; - - size_t n = 1000; - MV x(std::make_unique(s1)); - Results results(n); - - ThreadPool tp(8); - for (size_t i = 0; i < n; ++i) - { - tp.schedule(std::bind(thread1, std::ref(x), std::ref(results[i]))); - tp.schedule(std::bind(thread2, std::ref(x), (rand() % 2) ? s1 : s2)); - } - tp.wait(); - - for (size_t i = 0; i < n; ++i) - std::cerr << results[i] << " "; - std::cerr << std::endl; - } - catch (const Poco::Exception & e) - { - std::cerr << e.message() << std::endl; - throw; - } - - return 0; -} diff --git a/tests/README.md b/tests/README.md index bd46be3fd48..9286e2ac2c5 100644 --- a/tests/README.md +++ b/tests/README.md @@ -1,6 +1,6 @@ ## How to run test cases in local -For running intergration test cases (defined in `./fullstack-test`, `./fullstack-test-dt`, `./new_collation_fullstack`), you should define a TiDB cluster with TiFlash node (1 PD, 1 TiKV, 1 TiDB, 1 TiFlash at least). +For running integration test cases (defined in `./fullstack-test`, `./fullstack-test-dt`, `./new_collation_fullstack`), you should define a TiDB cluster with TiFlash node (1 PD, 1 TiKV, 1 TiDB, 1 TiFlash at least). 1. Build your own TiFlash binary using debug profile: @@ -10,7 +10,7 @@ For running intergration test cases (defined in `./fullstack-test`, `./fullstack ``` 2. Install [tiup](https://tiup.io/) -3. Use `tiup playground` (recommanded) or `tiup cluster` to start a tidb cluster +3. 
Use `tiup playground` (recommended) or `tiup cluster` to start a tidb cluster ```bash export TIDB_PORT=4000 diff --git a/tests/docker/cluster.yaml b/tests/docker/cluster.yaml index dd2f1dbf33f..0289c6bf601 100644 --- a/tests/docker/cluster.yaml +++ b/tests/docker/cluster.yaml @@ -16,7 +16,7 @@ version: '2.3' services: pd0: - image: hub.pingcap.net/qa/pd:${PD_BRANCH:-master} + image: ${PD_IMAGE:-hub.pingcap.net/tikv/pd/image:master} security_opt: - seccomp:unconfined volumes: @@ -35,7 +35,7 @@ services: - --log-file=/log/pd.log restart: on-failure tikv0: - image: hub.pingcap.net/qa/tikv:${TIKV_BRANCH:-master} + image: ${TIKV_IMAGE:-hub.pingcap.net/tikv/tikv/image:master} security_opt: - seccomp:unconfined volumes: @@ -55,7 +55,7 @@ services: - "pd0" restart: on-failure tidb0: - image: hub.pingcap.net/qa/tidb:${TIDB_BRANCH:-master} + image: ${TIDB_IMAGE:-hub.pingcap.net/pingcap/tidb/images/tidb-server:master} security_opt: - seccomp:unconfined volumes: diff --git a/tests/docker/cluster_tidb_fail_point.yaml b/tests/docker/cluster_tidb_fail_point.yaml index faec3c48898..39a39b55fad 100644 --- a/tests/docker/cluster_tidb_fail_point.yaml +++ b/tests/docker/cluster_tidb_fail_point.yaml @@ -16,7 +16,7 @@ version: '2.3' services: pd0: - image: hub.pingcap.net/qa/pd:${PD_BRANCH:-master} + image: ${PD_IMAGE:-hub.pingcap.net/tikv/pd/image:master} security_opt: - seccomp:unconfined volumes: @@ -35,7 +35,7 @@ services: - --log-file=/log/pd.log restart: on-failure tikv0: - image: hub.pingcap.net/qa/tikv:${TIKV_BRANCH:-master} + image: ${TIKV_IMAGE:-hub.pingcap.net/tikv/tikv/image:master} security_opt: - seccomp:unconfined volumes: @@ -55,7 +55,7 @@ services: - "pd0" restart: on-failure tidb0: - image: hub.pingcap.net/qa/tidb:${TIDB_BRANCH:-master}-failpoint + image: ${TIDB_IMAGE:-hub.pingcap.net/pingcap/tidb/images/tidb-server:master-failpoint} security_opt: - seccomp:unconfined environment: diff --git a/tests/fullstack-test/expr/cast_as_json.test 
b/tests/fullstack-test/expr/cast_as_json.test index 9e5e6c046a2..1654081e586 100644 --- a/tests/fullstack-test/expr/cast_as_json.test +++ b/tests/fullstack-test/expr/cast_as_json.test @@ -152,7 +152,7 @@ mysql> set @@tidb_isolation_read_engines='tiflash'; set @@tidb_enforce_mpp = 1; | cast(float_s as json) | cast(float_us as json) | +-----------------------+------------------------+ | NULL | NULL | -| 0 | 0 | +| 0.0 | 0.0 | | -999.9990234375 | 999.9990234375 | +-----------------------+------------------------+ mysql> set @@tidb_isolation_read_engines='tiflash'; set @@tidb_enforce_mpp = 1; select cast(double_s as json), cast(double_us as json) from test.test_float; @@ -160,7 +160,7 @@ mysql> set @@tidb_isolation_read_engines='tiflash'; set @@tidb_enforce_mpp = 1; | cast(double_s as json) | cast(double_us as json) | +------------------------+-------------------------+ | NULL | NULL | -| 0 | 0 | +| 0.0 | 0.0 | | -999.999 | 999.999 | +------------------------+-------------------------+ diff --git a/tests/fullstack-test/expr/duration_pushdown.test b/tests/fullstack-test/expr/duration_pushdown.test index e4679956472..028e4872066 100644 --- a/tests/fullstack-test/expr/duration_pushdown.test +++ b/tests/fullstack-test/expr/duration_pushdown.test @@ -121,7 +121,7 @@ mysql> create table test.time_test(id int(11),v1 time(3) not null, v2 time(3)); mysql> insert into test.time_test values(1,'20:20:20','20:20:20'); mysql> alter table test.time_test set tiflash replica 1; func> wait_table test time_test -mysql> use test; set tidb_enforce_mpp=1; set tidb_isolation_read_engines='tiflash'; select distinct v1 from(select v1 from time_test union all select v2 from time_test); +mysql> use test; set tidb_enforce_mpp=1; set tidb_isolation_read_engines='tiflash'; select distinct v1 from(select v1 from time_test union all select v2 from time_test) a; +--------------+ | v1 | +--------------+ diff --git a/tests/fullstack-test/mpp/window_agg.test b/tests/fullstack-test/mpp/window_agg.test 
index c34710276e5..b403562ca40 100644 --- a/tests/fullstack-test/mpp/window_agg.test +++ b/tests/fullstack-test/mpp/window_agg.test @@ -1418,3 +1418,81 @@ mysql> use test; set tidb_enforce_mpp=1; select p, o, v, sum(v) over w as "sum", | 6 | 3 | 3 | 39 | 9 | 2 | 8 | | 6 | 3 | 4 | 39 | 9 | 2 | 8 | +------+------+------+------+-------+------+------+ + +mysql> use test; set tidb_enforce_mpp=1; select p, o, v, min(v) over w as min_v, max(v) over w as max_v from t3 window w as (partition by p); ++------+------+-------------------------------------------------------------+--------------+-----------------------------------------------+ +| p | o | v | min_v | max_v | ++------+------+-------------------------------------------------------------+--------------+-----------------------------------------------+ +| 3 | 0 | g4g4opuim | | pmgo99f3 | +| 3 | 4 | phiiji | | pmgo99f3 | +| 3 | 6 | h45htrhPh5htrV | | pmgo99f3 | +| 3 | 10 | j7jhgvuev | | pmgo99f3 | +| 3 | 20 | pmgo99f3 | | pmgo99f3 | +| 3 | 40 | | | pmgo99f3 | +| 3 | 41 | TGh54h54htever | | pmgo99f3 | +| 6 | 0 | h45yh5f33 | fg4wfwj6j | h45yh5f33 | +| 6 | 10 | fg4wfwj6j | fg4wfwj6j | h45yh5f33 | +| 6 | 30 | gf34wf | fg4wfwj6j | h45yh5f33 | +| 7 | 0 | GFj6j6j6j6j6j32fwj6j6j6jyFf34 | G4gwh6jh6j6j | f43wg43wFG34y564g5 | +| 7 | 1 | G4wegfy54h5h5 | G4gwh6jh6j6j | f43wg43wFG34y564g5 | +| 7 | 2 | G4gwh6jh6j6j | G4gwh6jh6j6j | f43wg43wFG34y564g5 | +| 7 | 3 | Gh54h45hg6rh3f3 | G4gwh6jh6j6j | f43wg43wFG34y564g5 | +| 7 | 4 | f43wg43wFG34y564g5 | G4gwh6jh6j6j | f43wg43wFG34y564g5 | +| 2 | 0 | gh5ervfdgerbvresgkope4w59ujg430o9ggv4ij6h5eb5by5rv4wfvr | | vnuiwvch23978vgbuysdihvbyifj490whjv4iuow34gvr | +| 2 | 1 | vnuiwvch23978vgbuysdihvbyifj490whjv4iuow34gvr | | vnuiwvch23978vgbuysdihvbyifj490whjv4iuow34gvr | +| 2 | 4 | gre | | vnuiwvch23978vgbuysdihvbyifj490whjv4iuow34gvr | +| 2 | 7 | hg5df | | vnuiwvch23978vgbuysdihvbyifj490whjv4iuow34gvr | +| 2 | 8 | gdve | | vnuiwvch23978vgbuysdihvbyifj490whjv4iuow34gvr | +| 2 | 15 | | | 
vnuiwvch23978vgbuysdihvbyifj490whjv4iuow34gvr | +| 2 | 30 | gh345werfrf23ewg435g45Og45egrgbwerb4wgvb354vrg45erwgvw34fc4 | | vnuiwvch23978vgbuysdihvbyifj490whjv4iuow34gvr | +| 0 | 0 | | | | +| 1 | 0 | grwe4 | | htxfdsg | +| 1 | 1 | g43d | | htxfdsg | +| 1 | 3 | hg5g4wgdsrv34g24gfthdgf4gberbvf34brebrgfrew | | htxfdsg | +| 1 | 6 | g4sv | | htxfdsg | +| 1 | 7 | g4sfe | | htxfdsg | +| 1 | 8 | htxfdsg | | htxfdsg | +| 1 | 18 | | | htxfdsg | +| 1 | 30 | fw4efv43g5f32ghrtnbt4vt435g9v4v45gFYHWEUHVUEtfg4g4 | | htxfdsg | +| 4 | 0 | Gj6j6jf3gfw4 | Gj6j6jf3gfw4 | Gj6j6jf3gfw4 | +| 5 | 0 | g4effewsgfr | g4effewsgfr | g4effewsgfr | ++------+------+-------------------------------------------------------------+--------------+-----------------------------------------------+ + +mysql> use test; set tidb_enforce_mpp=1; select p, o, v, min(v) over w as min_v, max(v) over w as max_v from t3 window w as (partition by p order by o rows between 2 preceding and 1 following); ++------+------+-------------------------------------------------------------+---------------------------------------------------------+-------------------------------------------------------------+ +| p | o | v | min_v | max_v | ++------+------+-------------------------------------------------------------+---------------------------------------------------------+-------------------------------------------------------------+ +| 4 | 0 | Gj6j6jf3gfw4 | Gj6j6jf3gfw4 | Gj6j6jf3gfw4 | +| 0 | 0 | | | | +| 5 | 0 | g4effewsgfr | g4effewsgfr | g4effewsgfr | +| 2 | 0 | gh5ervfdgerbvresgkope4w59ujg430o9ggv4ij6h5eb5by5rv4wfvr | gh5ervfdgerbvresgkope4w59ujg430o9ggv4ij6h5eb5by5rv4wfvr | vnuiwvch23978vgbuysdihvbyifj490whjv4iuow34gvr | +| 2 | 1 | vnuiwvch23978vgbuysdihvbyifj490whjv4iuow34gvr | gh5ervfdgerbvresgkope4w59ujg430o9ggv4ij6h5eb5by5rv4wfvr | vnuiwvch23978vgbuysdihvbyifj490whjv4iuow34gvr | +| 2 | 4 | gre | gh5ervfdgerbvresgkope4w59ujg430o9ggv4ij6h5eb5by5rv4wfvr | vnuiwvch23978vgbuysdihvbyifj490whjv4iuow34gvr | +| 2 | 7 | hg5df 
| gdve | vnuiwvch23978vgbuysdihvbyifj490whjv4iuow34gvr | +| 2 | 8 | gdve | | hg5df | +| 2 | 15 | | | hg5df | +| 2 | 30 | gh345werfrf23ewg435g45Og45egrgbwerb4wgvb354vrg45erwgvw34fc4 | | gh345werfrf23ewg435g45Og45egrgbwerb4wgvb354vrg45erwgvw34fc4 | +| 3 | 0 | g4g4opuim | g4g4opuim | phiiji | +| 3 | 4 | phiiji | g4g4opuim | phiiji | +| 3 | 6 | h45htrhPh5htrV | g4g4opuim | phiiji | +| 3 | 10 | j7jhgvuev | h45htrhPh5htrV | pmgo99f3 | +| 3 | 20 | pmgo99f3 | | pmgo99f3 | +| 3 | 40 | | | pmgo99f3 | +| 3 | 41 | TGh54h54htever | | pmgo99f3 | +| 1 | 0 | grwe4 | g43d | grwe4 | +| 1 | 1 | g43d | g43d | hg5g4wgdsrv34g24gfthdgf4gberbvf34brebrgfrew | +| 1 | 3 | hg5g4wgdsrv34g24gfthdgf4gberbvf34brebrgfrew | g43d | hg5g4wgdsrv34g24gfthdgf4gberbvf34brebrgfrew | +| 1 | 6 | g4sv | g43d | hg5g4wgdsrv34g24gfthdgf4gberbvf34brebrgfrew | +| 1 | 7 | g4sfe | g4sfe | htxfdsg | +| 1 | 8 | htxfdsg | | htxfdsg | +| 1 | 18 | | | htxfdsg | +| 1 | 30 | fw4efv43g5f32ghrtnbt4vt435g9v4v45gFYHWEUHVUEtfg4g4 | | htxfdsg | +| 6 | 0 | h45yh5f33 | fg4wfwj6j | h45yh5f33 | +| 6 | 10 | fg4wfwj6j | fg4wfwj6j | h45yh5f33 | +| 6 | 30 | gf34wf | fg4wfwj6j | h45yh5f33 | +| 7 | 0 | GFj6j6j6j6j6j32fwj6j6j6jyFf34 | G4wegfy54h5h5 | GFj6j6j6j6j6j32fwj6j6j6jyFf34 | +| 7 | 1 | G4wegfy54h5h5 | G4gwh6jh6j6j | GFj6j6j6j6j6j32fwj6j6j6jyFf34 | +| 7 | 2 | G4gwh6jh6j6j | G4gwh6jh6j6j | Gh54h45hg6rh3f3 | +| 7 | 3 | Gh54h45hg6rh3f3 | G4gwh6jh6j6j | f43wg43wFG34y564g5 | +| 7 | 4 | f43wg43wFG34y564g5 | G4gwh6jh6j6j | f43wg43wFG34y564g5 | ++------+------+-------------------------------------------------------------+---------------------------------------------------------+-------------------------------------------------------------+ diff --git a/tests/sanitize/asan.suppression b/tests/sanitize/asan.suppression index 9295ff53745..ef26166ef96 100644 --- a/tests/sanitize/asan.suppression +++ b/tests/sanitize/asan.suppression @@ -1 +1,3 @@ leak:fiu_enable +# The CppRawPtr in FastAddPeerRes will be automatically deleted in rust 
+leak:genFastAddPeerRes diff --git a/asan_ignores.txt b/tests/sanitize/asan_ignores.txt similarity index 100% rename from asan_ignores.txt rename to tests/sanitize/asan_ignores.txt diff --git a/tests/sanitize/tsan.suppression b/tests/sanitize/tsan.suppression index e9e99557ce7..57242747a4a 100644 --- a/tests/sanitize/tsan.suppression +++ b/tests/sanitize/tsan.suppression @@ -8,3 +8,8 @@ race:dbms/src/DataStreams/BlockStreamProfileInfo.h race:StackTrace::toString race:DB::SyncPointCtl::sync race:XXH3_hashLong_64b_withSeed_selection +race:re2::RE2::NumberOfCapturingGroups +# PathPool is used in lot of places, but TiFlashStorageTestBasic::reload will try to write it. +# Since we will never call Context::setPathPool after TiFlash is initialized, it is safe to suppress. +race:TiFlashStorageTestBasic::reload +race:*::~shared_ptr From 733aa1e3cc34517c44d8af670993399fdd5376ce Mon Sep 17 00:00:00 2001 From: gengliqi Date: Sat, 22 Mar 2025 14:39:49 +0800 Subject: [PATCH 03/84] u Signed-off-by: gengliqi --- dbms/src/Storages/DeltaMerge/tests/gtest_inverted_index.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/dbms/src/Storages/DeltaMerge/tests/gtest_inverted_index.cpp b/dbms/src/Storages/DeltaMerge/tests/gtest_inverted_index.cpp index c0b027bca24..4541d089842 100644 --- a/dbms/src/Storages/DeltaMerge/tests/gtest_inverted_index.cpp +++ b/dbms/src/Storages/DeltaMerge/tests/gtest_inverted_index.cpp @@ -151,6 +151,7 @@ class InvertedIndexTest viewer->search(bitmap_filter, key); ASSERT_EQ(bitmap_filter->count(), expected_count); }; + std::vector threads; { for (UInt32 i = 0; i < 10; ++i) From 3d9507bb2f0a3d2d0ba09b94e26b1003dcdddca7 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Mon, 24 Mar 2025 17:51:39 +0800 Subject: [PATCH 04/84] update tests Signed-off-by: gengliqi --- dbms/src/Flash/tests/gtest_join_executor.cpp | 159 ++++++++++++++----- 1 file changed, 119 insertions(+), 40 deletions(-) diff --git a/dbms/src/Flash/tests/gtest_join_executor.cpp 
b/dbms/src/Flash/tests/gtest_join_executor.cpp index cdde704ec13..f8359cee396 100644 --- a/dbms/src/Flash/tests/gtest_join_executor.cpp +++ b/dbms/src/Flash/tests/gtest_join_executor.cpp @@ -21,6 +21,8 @@ namespace DB namespace FailPoints { extern const char force_semi_join_time_exceed[]; +extern const char force_join_v2_probe_enable_lm[]; +extern const char force_join_v2_probe_disable_lm[]; } // namespace FailPoints namespace tests { @@ -34,6 +36,28 @@ class JoinExecutorTestRunner : public DB::tests::JoinTestRunner /// disable spill context.context->setSetting("max_bytes_before_external_join", Field(static_cast(0))); + initJoinTestConfig(configs); + initJoinTestConfig(enable_lm_configs); + } + + struct JoinTestConfig + { + JoinTestConfig(bool enable_pipeline_, bool enable_join_v2_, UInt64 prefetch_threshold_) + : enable_pipeline(enable_pipeline_) + , enable_join_v2(enable_join_v2_) + , join_v2_prefetch_threshold(prefetch_threshold_) + {} + bool enable_pipeline; + bool enable_join_v2; + UInt64 join_v2_prefetch_threshold; + bool join_v2_enable_lm = false; + }; + std::vector configs; + std::vector enable_lm_configs; + + template + static void initJoinTestConfig(std::vector & cfgs) + { for (auto enable_pipeline : {false, true}) { if (enable_pipeline) @@ -43,29 +67,27 @@ class JoinExecutorTestRunner : public DB::tests::JoinTestRunner if (enable_join_v2) { for (UInt64 prefetch_threshold : {0, 100000000}) - configs.emplace_back(enable_pipeline, enable_join_v2, prefetch_threshold); + { + std::vector enable_lm_vec; + if constexpr (can_enable_lm) + enable_lm_vec = {false, true}; + else + enable_lm_vec = {false}; + for (auto enable_lm : enable_lm_vec) + { + cfgs.emplace_back(enable_pipeline, enable_join_v2, prefetch_threshold); + cfgs.back().join_v2_enable_lm = enable_lm; + } + } } else - configs.emplace_back(enable_pipeline, enable_join_v2, 0); + cfgs.emplace_back(enable_pipeline, enable_join_v2, 0); } } else - configs.emplace_back(enable_pipeline, false, 0); + 
cfgs.emplace_back(enable_pipeline, false, 0); } } - - struct JoinTestConfig - { - JoinTestConfig(bool enable_pipeline_, bool enable_join_v2_, UInt64 prefetch_threshold_) - : enable_pipeline(enable_pipeline_) - , enable_join_v2(enable_join_v2_) - , prefetch_threshold(prefetch_threshold_) - {} - bool enable_pipeline; - bool enable_join_v2; - UInt64 prefetch_threshold; - }; - std::vector configs; }; #define WRAP_FOR_JOIN_TEST_BEGIN \ @@ -73,10 +95,34 @@ class JoinExecutorTestRunner : public DB::tests::JoinTestRunner { \ enablePipeline(cfg.enable_pipeline); \ context.context->getSettingsRef().enable_hash_join_v2 = cfg.enable_join_v2; \ - context.context->getSettingsRef().join_v2_probe_enable_prefetch_threshold = cfg.prefetch_threshold; + context.context->getSettingsRef().join_v2_probe_enable_prefetch_threshold = cfg.join_v2_prefetch_threshold; #define WRAP_FOR_JOIN_TEST_END } +#define WRAP_FOR_JOIN_FOR_OTHER_CONDITION_TEST_BEGIN \ + for (auto cfg : configs) \ + { \ + enablePipeline(cfg.enable_pipeline); \ + context.context->getSettingsRef().enable_hash_join_v2 = cfg.enable_join_v2; \ + context.context->getSettingsRef().join_v2_probe_enable_prefetch_threshold = cfg.join_v2_prefetch_threshold; \ + if (cfg.enable_join_v2) \ + { \ + if (cfg.join_v2_enable_lm) \ + FailPointHelper::enableFailPoint(FailPoints::force_join_v2_probe_enable_lm); \ + else \ + FailPointHelper::enableFailPoint(FailPoints::force_join_v2_probe_disable_lm); \ + } + +#define WRAP_FOR_JOIN_FOR_OTHER_CONDITION_TEST_END \ + if (cfg.enable_join_v2) \ + { \ + if (cfg.join_v2_enable_lm) \ + FailPointHelper::disableFailPoint(FailPoints::force_join_v2_probe_enable_lm); \ + else \ + FailPointHelper::disableFailPoint(FailPoints::force_join_v2_probe_disable_lm); \ + } \ + } + TEST_F(JoinExecutorTestRunner, SimpleJoin) try { @@ -191,7 +237,6 @@ try toNullableVec({0, 0, 0, 1, 1})}, }; - WRAP_FOR_JOIN_TEST_BEGIN std::vector probe_cache_column_threshold{2, 1000}; for (size_t i = 0; i < join_type_num; ++i) { @@ -210,14 
+255,16 @@ try context.context->setSetting( "join_probe_cache_columns_threshold", Field(static_cast(threshold))); + + WRAP_FOR_JOIN_TEST_BEGIN executeAndAssertColumnsEqual(request, expected_cols[i * simple_test_num + j]); ASSERT_COLUMNS_EQ_UR( genScalarCountResults(expected_cols[i * simple_test_num + j]), executeStreams(request_column_prune, 2)); + WRAP_FOR_JOIN_TEST_END } } } - WRAP_FOR_JOIN_TEST_END } CATCH @@ -590,7 +637,6 @@ try toNullableVec({0, 0, 0})}, }; - WRAP_FOR_JOIN_TEST_BEGIN /// select * from (t1 JT1 t2 using (a)) JT2 (t3 JT1 t4 using (a)) using (b) for (auto [i, jt1] : ext::enumerate(join_types)) { @@ -604,7 +650,9 @@ try auto request = t1.join(t2, jt1, {col("a")}).join(t3.join(t4, jt1, {col("a")}), jt2, {col("b")}).build(context); + WRAP_FOR_JOIN_TEST_BEGIN executeAndAssertColumnsEqual(request, expected_cols[i * join_type_num + j]); + WRAP_FOR_JOIN_TEST_END } { auto t1 = context.scan("multi_test", "t1"); @@ -615,13 +663,14 @@ try .join(t3.join(t4, jt1, {col("a")}), jt2, {col("b")}) .aggregation({Count(lit(static_cast(1)))}, {}) .build(context); + WRAP_FOR_JOIN_TEST_BEGIN ASSERT_COLUMNS_EQ_UR( genScalarCountResults(expected_cols[i * join_type_num + j]), executeStreams(request_column_prune, 2)); + WRAP_FOR_JOIN_TEST_END } } } - WRAP_FOR_JOIN_TEST_END } CATCH @@ -640,8 +689,6 @@ try .build(context); }; - WRAP_FOR_JOIN_TEST_BEGIN - ColumnsWithTypeAndName column_prune_ref_columns; column_prune_ref_columns.push_back(toVec({1})); @@ -650,64 +697,80 @@ try context.addMockTable("cast", "t2", {{"a", TiDB::TP::TypeFloat}}, {toVec("a", {1.0})}); + WRAP_FOR_JOIN_TEST_BEGIN executeAndAssertColumnsEqual(cast_request(), {toNullableVec({1}), toNullableVec({1.0})}); ASSERT_COLUMNS_EQ_UR(column_prune_ref_columns, executeStreams(cast_column_prune_request(), 2)); + WRAP_FOR_JOIN_TEST_END /// int(1) == double(1.0) context.addMockTable("cast", "t1", {{"a", TiDB::TP::TypeLong}}, {toVec("a", {1})}); context.addMockTable("cast", "t2", {{"a", TiDB::TP::TypeDouble}}, 
{toVec("a", {1.0})}); + WRAP_FOR_JOIN_TEST_BEGIN executeAndAssertColumnsEqual(cast_request(), {toNullableVec({1}), toNullableVec({1.0})}); ASSERT_COLUMNS_EQ_UR(column_prune_ref_columns, executeStreams(cast_column_prune_request(), 2)); + WRAP_FOR_JOIN_TEST_END /// float(1) == double(1.0) context.addMockTable("cast", "t1", {{"a", TiDB::TP::TypeFloat}}, {toVec("a", {1})}); context.addMockTable("cast", "t2", {{"a", TiDB::TP::TypeDouble}}, {toVec("a", {1})}); + WRAP_FOR_JOIN_TEST_BEGIN executeAndAssertColumnsEqual(cast_request(), {toNullableVec({1}), toNullableVec({1})}); ASSERT_COLUMNS_EQ_UR(column_prune_ref_columns, executeStreams(cast_column_prune_request(), 2)); + WRAP_FOR_JOIN_TEST_END /// varchar('x') == char('x') context.addMockTable("cast", "t1", {{"a", TiDB::TP::TypeString}}, {toVec("a", {"x"})}); context.addMockTable("cast", "t2", {{"a", TiDB::TP::TypeVarchar}}, {toVec("a", {"x"})}); + WRAP_FOR_JOIN_TEST_BEGIN executeAndAssertColumnsEqual(cast_request(), {toNullableVec({"x"}), toNullableVec({"x"})}); ASSERT_COLUMNS_EQ_UR(column_prune_ref_columns, executeStreams(cast_column_prune_request(), 2)); + WRAP_FOR_JOIN_TEST_END /// tinyblob('x') == varchar('x') context.addMockTable("cast", "t1", {{"a", TiDB::TP::TypeTinyBlob}}, {toVec("a", {"x"})}); context.addMockTable("cast", "t2", {{"a", TiDB::TP::TypeVarchar}}, {toVec("a", {"x"})}); + WRAP_FOR_JOIN_TEST_BEGIN executeAndAssertColumnsEqual(cast_request(), {toNullableVec({"x"}), toNullableVec({"x"})}); ASSERT_COLUMNS_EQ_UR(column_prune_ref_columns, executeStreams(cast_column_prune_request(), 2)); + WRAP_FOR_JOIN_TEST_END /// mediumBlob('x') == varchar('x') context.addMockTable("cast", "t1", {{"a", TiDB::TP::TypeMediumBlob}}, {toVec("a", {"x"})}); context.addMockTable("cast", "t2", {{"a", TiDB::TP::TypeVarchar}}, {toVec("a", {"x"})}); + WRAP_FOR_JOIN_TEST_BEGIN executeAndAssertColumnsEqual(cast_request(), {toNullableVec({"x"}), toNullableVec({"x"})}); ASSERT_COLUMNS_EQ_UR(column_prune_ref_columns, 
executeStreams(cast_column_prune_request(), 2)); + WRAP_FOR_JOIN_TEST_END /// blob('x') == varchar('x') context.addMockTable("cast", "t1", {{"a", TiDB::TP::TypeBlob}}, {toVec("a", {"x"})}); context.addMockTable("cast", "t2", {{"a", TiDB::TP::TypeVarchar}}, {toVec("a", {"x"})}); + WRAP_FOR_JOIN_TEST_BEGIN executeAndAssertColumnsEqual(cast_request(), {toNullableVec({"x"}), toNullableVec({"x"})}); ASSERT_COLUMNS_EQ_UR(column_prune_ref_columns, executeStreams(cast_column_prune_request(), 2)); + WRAP_FOR_JOIN_TEST_END /// longBlob('x') == varchar('x') context.addMockTable("cast", "t1", {{"a", TiDB::TP::TypeLongBlob}}, {toVec("a", {"x"})}); context.addMockTable("cast", "t2", {{"a", TiDB::TP::TypeVarchar}}, {toVec("a", {"x"})}); + WRAP_FOR_JOIN_TEST_BEGIN executeAndAssertColumnsEqual(cast_request(), {toNullableVec({"x"}), toNullableVec({"x"})}); ASSERT_COLUMNS_EQ_UR(column_prune_ref_columns, executeStreams(cast_column_prune_request(), 2)); + WRAP_FOR_JOIN_TEST_END /// decimal with different scale context.addMockTable( @@ -722,11 +785,13 @@ try {{"a", TiDB::TP::TypeNewDecimal}}, {createColumn(std::make_tuple(9, 3), {"0.12"}, "a")}); + WRAP_FOR_JOIN_TEST_BEGIN executeAndAssertColumnsEqual( cast_request(), {createNullableColumn(std::make_tuple(65, 0), {"0.12"}, {0}), createNullableColumn(std::make_tuple(65, 0), {"0.12"}, {0})}); ASSERT_COLUMNS_EQ_UR(column_prune_ref_columns, executeStreams(cast_column_prune_request(), 2)); + WRAP_FOR_JOIN_TEST_END /// datetime(1970-01-01 00:00:01) == timestamp(1970-01-01 00:00:01) context.addMockTable( @@ -752,12 +817,11 @@ try .aggregation({Count(lit(static_cast(1)))}, {}) .build(context); }; + WRAP_FOR_JOIN_TEST_BEGIN executeAndAssertColumnsEqual( cast_request_1(), {createDateTimeColumn({{{1970, 1, 1, 0, 0, 1, 0}}}, 0), createDateTimeColumn({{{1970, 1, 1, 0, 0, 1, 0}}}, 0)}); - ASSERT_COLUMNS_EQ_UR(column_prune_ref_columns, executeStreams(cast_column_prune_request_1(), 2)); - WRAP_FOR_JOIN_TEST_END } CATCH @@ -799,7 +863,6 @@ try 
toNullableVec({1, 4})}, }; - WRAP_FOR_JOIN_TEST_BEGIN for (auto [i, tp] : ext::enumerate(join_types)) { auto request = context.scan("join_agg", "t1") @@ -807,9 +870,10 @@ try .aggregation({Max(col("a")), Min(col("a")), Count(col("a"))}, {col("b")}) .build(context); + WRAP_FOR_JOIN_TEST_BEGIN executeAndAssertColumnsEqual(request, expected_cols[i]); + WRAP_FOR_JOIN_TEST_END } - WRAP_FOR_JOIN_TEST_END } CATCH @@ -1512,7 +1576,6 @@ try {{"id", TiDB::TP::TypeLongLong}, {"probe_value", TiDB::TP::TypeLongLong}}, {probe_key, probe_col}); - WRAP_FOR_JOIN_TEST_BEGIN context.context->setSetting("max_block_size", Field(static_cast(90))); { auto anti_join_request = context.scan("issue_8791", "probe_table") @@ -1528,7 +1591,9 @@ try .build(context); auto expected_columns = {toVec({16})}; + WRAP_FOR_JOIN_FOR_OTHER_CONDITION_TEST_BEGIN ASSERT_COLUMNS_EQ_UR(expected_columns, executeStreams(anti_join_request, 1)); + WRAP_FOR_JOIN_FOR_OTHER_CONDITION_TEST_END } { auto inner_join_request = context.scan("issue_8791", "probe_table") @@ -1544,9 +1609,10 @@ try .build(context); auto expected_columns = {toVec({240})}; + WRAP_FOR_JOIN_FOR_OTHER_CONDITION_TEST_BEGIN ASSERT_COLUMNS_EQ_UR(expected_columns, executeStreams(inner_join_request, 1)); + WRAP_FOR_JOIN_FOR_OTHER_CONDITION_TEST_END } - WRAP_FOR_JOIN_TEST_END } CATCH @@ -1911,6 +1977,7 @@ try .aggregation({Count(lit(static_cast(1)))}, {}) .build(context); { + WRAP_FOR_JOIN_TEST_BEGIN executeAndAssertColumnsEqual( request, {toNullableVec({"banana", "banana"}), @@ -1918,6 +1985,7 @@ try toNullableVec({"banana", "banana"}), toNullableVec({"apple", "banana"})}); ASSERT_COLUMNS_EQ_UR(genScalarCountResults(2), executeStreams(request_column_prune, 2)); + WRAP_FOR_JOIN_TEST_END } request = context.scan("test_db", "l_table") @@ -1925,21 +1993,25 @@ try .project({"s", "join_c"}) .build(context); { + WRAP_FOR_JOIN_TEST_BEGIN executeAndAssertColumnsEqual( request, {toNullableVec({"banana", "banana"}), toNullableVec({"apple", "banana"})}); + 
WRAP_FOR_JOIN_TEST_END } request = context.scan("test_db", "l_table") .join(context.scan("test_db", "r_table_2"), tipb::JoinType::TypeLeftOuterJoin, {col("join_c")}) .build(context); { + WRAP_FOR_JOIN_TEST_BEGIN executeAndAssertColumnsEqual( request, {toNullableVec({"banana", "banana", "banana", "banana"}), toNullableVec({"apple", "apple", "apple", "banana"}), toNullableVec({"banana", "banana", "banana", {}}), toNullableVec({"apple", "apple", "apple", {}})}); + WRAP_FOR_JOIN_TEST_END } } CATCH @@ -1959,9 +2031,9 @@ try {}) .aggregation({Count(lit(static_cast(1)))}, {}) .build(context); - WRAP_FOR_JOIN_TEST_BEGIN + WRAP_FOR_JOIN_FOR_OTHER_CONDITION_TEST_BEGIN ASSERT_COLUMNS_EQ_UR(genScalarCountResults(2), executeStreams(request_column, 2)); - WRAP_FOR_JOIN_TEST_END + WRAP_FOR_JOIN_FOR_OTHER_CONDITION_TEST_END } CATCH @@ -2022,24 +2094,31 @@ try std::shared_ptr request; std::shared_ptr request_column_prune; - WRAP_FOR_JOIN_TEST_BEGIN // inner join { // null table join non-null table request = context.scan("null_test", "null_table") .join(context.scan("null_test", "t"), tipb::JoinType::TypeInnerJoin, {col("a")}) .build(context); - executeAndAssertColumnsEqual(request, {}); request_column_prune = context.scan("null_test", "null_table") .join(context.scan("null_test", "t"), tipb::JoinType::TypeInnerJoin, {col("a")}) .aggregation({Count(lit(static_cast(1)))}, {}) .build(context); + WRAP_FOR_JOIN_TEST_BEGIN + executeAndAssertColumnsEqual(request, {}); ASSERT_COLUMNS_EQ_UR(genScalarCountResults(0), executeStreams(request_column_prune, 2)); + WRAP_FOR_JOIN_TEST_END // non-null table join null table request = context.scan("null_test", "t") .join(context.scan("null_test", "null_table"), tipb::JoinType::TypeInnerJoin, {col("a")}) .build(context); + request_column_prune + = context.scan("null_test", "t") + .join(context.scan("null_test", "null_table"), tipb::JoinType::TypeInnerJoin, {col("a")}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + 
WRAP_FOR_JOIN_TEST_BEGIN executeAndAssertColumnsEqual( request, {toNullableVec({}), @@ -2048,24 +2127,22 @@ try toNullableVec({}), toNullableVec({}), toNullableVec({})}); - request_column_prune - = context.scan("null_test", "t") - .join(context.scan("null_test", "null_table"), tipb::JoinType::TypeInnerJoin, {col("a")}) - .aggregation({Count(lit(static_cast(1)))}, {}) - .build(context); ASSERT_COLUMNS_EQ_UR(genScalarCountResults(0), executeStreams(request_column_prune, 2)); + WRAP_FOR_JOIN_TEST_END // null table join null table request = context.scan("null_test", "null_table") .join(context.scan("null_test", "null_table"), tipb::JoinType::TypeInnerJoin, {col("a")}) .build(context); - executeAndAssertColumnsEqual(request, {}); request_column_prune = context.scan("null_test", "null_table") .join(context.scan("null_test", "null_table"), tipb::JoinType::TypeInnerJoin, {col("a")}) .aggregation({Count(lit(static_cast(1)))}, {}) .build(context); + WRAP_FOR_JOIN_TEST_BEGIN + executeAndAssertColumnsEqual(request, {}); ASSERT_COLUMNS_EQ_UR(genScalarCountResults(0), executeStreams(request_column_prune, 2)); + WRAP_FOR_JOIN_TEST_END } // cross join @@ -2519,7 +2596,6 @@ try .build(context); ASSERT_COLUMNS_EQ_UR(genScalarCountResults(0), executeStreams(request_column_prune, 2)); } - WRAP_FOR_JOIN_TEST_END } CATCH @@ -4532,5 +4608,8 @@ CATCH #undef WRAP_FOR_JOIN_TEST_BEGIN #undef WRAP_FOR_JOIN_TEST_END +#undef WRAP_FOR_JOIN_FOR_OTHER_CONDITION_TEST_BEGIN +#undef WRAP_FOR_JOIN_FOR_OTHER_CONDITION_TEST_END + } // namespace tests } // namespace DB From 44a76d122dac3e4f6be09743e74152cb43b464ea Mon Sep 17 00:00:00 2001 From: gengliqi Date: Tue, 25 Mar 2025 03:37:36 +0800 Subject: [PATCH 05/84] update left-outer with other condition tests Signed-off-by: gengliqi --- dbms/src/Flash/tests/gtest_join_executor.cpp | 209 +++++++++++++++--- dbms/src/Interpreters/JoinV2/HashJoin.cpp | 2 +- .../src/Interpreters/JoinV2/HashJoinProbe.cpp | 9 +- dbms/src/Interpreters/JoinV2/HashJoinProbe.h | 11 
+- 4 files changed, 194 insertions(+), 37 deletions(-) diff --git a/dbms/src/Flash/tests/gtest_join_executor.cpp b/dbms/src/Flash/tests/gtest_join_executor.cpp index f8359cee396..312a63738c3 100644 --- a/dbms/src/Flash/tests/gtest_join_executor.cpp +++ b/dbms/src/Flash/tests/gtest_join_executor.cpp @@ -42,12 +42,10 @@ class JoinExecutorTestRunner : public DB::tests::JoinTestRunner struct JoinTestConfig { - JoinTestConfig(bool enable_pipeline_, bool enable_join_v2_, UInt64 prefetch_threshold_) - : enable_pipeline(enable_pipeline_) - , enable_join_v2(enable_join_v2_) + JoinTestConfig(bool enable_join_v2_, UInt64 prefetch_threshold_) + : enable_join_v2(enable_join_v2_) , join_v2_prefetch_threshold(prefetch_threshold_) {} - bool enable_pipeline; bool enable_join_v2; UInt64 join_v2_prefetch_threshold; bool join_v2_enable_lm = false; @@ -58,34 +56,26 @@ class JoinExecutorTestRunner : public DB::tests::JoinTestRunner template static void initJoinTestConfig(std::vector & cfgs) { - for (auto enable_pipeline : {false, true}) + for (auto enable_join_v2 : {false, true}) { - if (enable_pipeline) + if (enable_join_v2) { - for (auto enable_join_v2 : {false, true}) + for (UInt64 prefetch_threshold : {0, 100000000}) { - if (enable_join_v2) + std::vector enable_lm_vec; + if constexpr (can_enable_lm) + enable_lm_vec = {false, true}; + else + enable_lm_vec = {false}; + for (auto enable_lm : enable_lm_vec) { - for (UInt64 prefetch_threshold : {0, 100000000}) - { - std::vector enable_lm_vec; - if constexpr (can_enable_lm) - enable_lm_vec = {false, true}; - else - enable_lm_vec = {false}; - for (auto enable_lm : enable_lm_vec) - { - cfgs.emplace_back(enable_pipeline, enable_join_v2, prefetch_threshold); - cfgs.back().join_v2_enable_lm = enable_lm; - } - } + cfgs.emplace_back(enable_join_v2, prefetch_threshold); + cfgs.back().join_v2_enable_lm = enable_lm; } - else - cfgs.emplace_back(enable_pipeline, enable_join_v2, 0); } } else - cfgs.emplace_back(enable_pipeline, false, 0); + 
cfgs.emplace_back(enable_join_v2, 0); } } }; @@ -93,16 +83,14 @@ class JoinExecutorTestRunner : public DB::tests::JoinTestRunner #define WRAP_FOR_JOIN_TEST_BEGIN \ for (auto cfg : configs) \ { \ - enablePipeline(cfg.enable_pipeline); \ context.context->getSettingsRef().enable_hash_join_v2 = cfg.enable_join_v2; \ context.context->getSettingsRef().join_v2_probe_enable_prefetch_threshold = cfg.join_v2_prefetch_threshold; #define WRAP_FOR_JOIN_TEST_END } #define WRAP_FOR_JOIN_FOR_OTHER_CONDITION_TEST_BEGIN \ - for (auto cfg : configs) \ + for (auto cfg : enable_lm_configs) \ { \ - enablePipeline(cfg.enable_pipeline); \ context.context->getSettingsRef().enable_hash_join_v2 = cfg.enable_join_v2; \ context.context->getSettingsRef().join_v2_probe_enable_prefetch_threshold = cfg.join_v2_prefetch_threshold; \ if (cfg.enable_join_v2) \ @@ -2019,7 +2007,7 @@ CATCH TEST_F(JoinExecutorTestRunner, LeftJoinAggWithOtherCondition) try { - auto request_column + auto request = context.scan("test_db", "l_table") .join( context.scan("test_db", "r_table"), @@ -2032,7 +2020,170 @@ try .aggregation({Count(lit(static_cast(1)))}, {}) .build(context); WRAP_FOR_JOIN_FOR_OTHER_CONDITION_TEST_BEGIN - ASSERT_COLUMNS_EQ_UR(genScalarCountResults(2), executeStreams(request_column, 2)); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(2), executeStreams(request, 2)); + WRAP_FOR_JOIN_FOR_OTHER_CONDITION_TEST_END +} +CATCH + +TEST_F(JoinExecutorTestRunner, LeftOuterJoin) +try +{ + context.addMockTable( + {"test_db", "lj_r_table"}, + { + {"r1", TiDB::TP::TypeString}, + {"k1", TiDB::TP::TypeLongLong}, + {"k2", TiDB::TP::TypeShort}, + {"r2", TiDB::TP::TypeString}, + {"r3", TiDB::TP::TypeLong}, + }, + { + toVec("r1", {"apple", "banana", "cat", "dog", "elephant", "frag"}), + toVec("k1", {1, 1, 2, 3, 2, 4}), + toVec("k2", {1, 1, 2, 3, 2, 5}), + toVec("r2", {"aaa", "bbb", "ccc", "ddd", "eee", "fff"}), + toVec("r3", {1, 2, 3, 4, 5, 6}), + }); + + context.addMockTable( + {"test_db", "lj_l_table"}, + { + {"l1", 
TiDB::TP::TypeString}, + {"k1", TiDB::TP::TypeLongLong}, + {"k2", TiDB::TP::TypeShort}, + {"l2", TiDB::TP::TypeLong}, + {"l3", TiDB::TP::TypeLong}, + {"l4", TiDB::TP::TypeLongLong}, + }, + { + toVec("l1", {"AAA", "BBB", "CCC", "DDD", "EEE", "FFF", "GGG", "HHH", "III", "JJJ", "KKK", "LLL"}), + toNullableVec("k1", {1, 1, 2, 2, {}, 3, 3, 3, 2, 3, 1, 2}), + toNullableVec("k2", {1, {}, 2, 9, 2, 3, 3, 3, 2, 3, 1, 2}), + toVec("l2", {1, 2, 3, 4, 5, 6, 6, 6, 7, 8, 9, 10}), + toNullableVec("l3", {1, 2, 3, 4, 5, 0, {}, 6, 7, 8, 9, 10}), + toNullableVec("l4", {3, 1, 2, 0, 2, 3, 3, {}, 4, 5, 0, 6}), + }); + + // No other condition + auto request = context.scan("test_db", "lj_l_table") + .join( + context.scan("test_db", "lj_r_table"), + tipb::JoinType::TypeLeftOuterJoin, + {col("k1"), col("k2")}, + {eq(col("lj_l_table.l2"), col("lj_l_table.l3"))}, + {}, + {}, + {}) + .project( + {"lj_r_table.r1", + "lj_r_table.r2", + "lj_r_table.r3", + "lj_r_table.k2", + "lj_l_table.k2", + "lj_l_table.l1", + "lj_l_table.l2", + "lj_l_table.l4"}) + .build(context); + WRAP_FOR_JOIN_TEST_BEGIN + executeAndAssertColumnsEqual( + request, + { + toNullableVec( + {"apple", + "banana", + {}, + "cat", + "elephant", + {}, + {}, + {}, + {}, + "dog", + "cat", + "elephant", + "dog", + "apple", + "banana", + "cat", + "elephant"}), + toNullableVec( + {"aaa", + "bbb", + {}, + "ccc", + "eee", + {}, + {}, + {}, + {}, + "ddd", + "ccc", + "eee", + "ddd", + "aaa", + "bbb", + "ccc", + "eee"}), + toNullableVec({1, 2, {}, 3, 5, {}, {}, {}, {}, 4, 3, 5, 4, 1, 2, 3, 5}), + toNullableVec({1, 1, {}, 2, 2, {}, {}, {}, {}, 3, 2, 2, 3, 1, 1, 2, 2}), + toNullableVec({1, 1, {}, 2, 2, 9, 2, 3, 3, 3, 2, 2, 3, 1, 1, 2, 2}), + toNullableVec( + {"AAA", + "AAA", + "BBB", + "CCC", + "CCC", + "DDD", + "EEE", + "FFF", + "GGG", + "HHH", + "III", + "III", + "JJJ", + "KKK", + "KKK", + "LLL", + "LLL"}), + toNullableVec({1, 1, 2, 3, 3, 4, 5, 6, 6, 6, 7, 7, 8, 9, 9, 10, 10}), + toNullableVec({3, 3, 1, 2, 2, 0, 2, 3, 3, {}, 4, 4, 5, 0, 0, 6, 
6}), + }); + WRAP_FOR_JOIN_TEST_END + + // Has other condition + request = context.scan("test_db", "lj_l_table") + .join( + context.scan("test_db", "lj_r_table"), + tipb::JoinType::TypeLeftOuterJoin, + {col("k1"), col("k2")}, + {eq(col("lj_l_table.l2"), col("lj_l_table.l3"))}, + {}, + {gt(col("lj_l_table.l4"), col("lj_r_table.r3"))}, + {}) + .project( + {"lj_r_table.r1", + "lj_r_table.r2", + "lj_r_table.r3", + "lj_r_table.k2", + "lj_l_table.k2", + "lj_l_table.l1", + "lj_l_table.l2", + "lj_l_table.l4"}) + .build(context); + WRAP_FOR_JOIN_FOR_OTHER_CONDITION_TEST_BEGIN + executeAndAssertColumnsEqual( + request, + { + toNullableVec({"apple", "banana", {}, {}, {}, {}, {}, {}, {}, "cat", "dog", {}, "cat", "elephant"}), + toNullableVec({"aaa", "bbb", {}, {}, {}, {}, {}, {}, {}, "ccc", "ddd", {}, "ccc", "eee"}), + toNullableVec({1, 2, {}, {}, {}, {}, {}, {}, {}, 3, 4, {}, 3, 5}), + toNullableVec({1, 1, {}, {}, {}, {}, {}, {}, {}, 2, 3, {}, 2, 2}), + toNullableVec({1, 1, {}, 2, 9, 2, 3, 3, 3, 2, 3, 1, 2, 2}), + toNullableVec( + {"AAA", "AAA", "BBB", "CCC", "DDD", "EEE", "FFF", "GGG", "HHH", "III", "JJJ", "KKK", "LLL", "LLL"}), + toNullableVec({1, 1, 2, 3, 4, 5, 6, 6, 6, 7, 8, 9, 10, 10}), + toNullableVec({3, 3, 1, 2, 0, 2, 3, 3, {}, 4, 5, 0, 6, 6}), + }); WRAP_FOR_JOIN_FOR_OTHER_CONDITION_TEST_END } CATCH diff --git a/dbms/src/Interpreters/JoinV2/HashJoin.cpp b/dbms/src/Interpreters/JoinV2/HashJoin.cpp index 8c7cc817793..4be1d5d9b70 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoin.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoin.cpp @@ -466,7 +466,7 @@ void HashJoin::workAfterBuildRowFinish() LOG_DEBUG( log, - "allocate pointer table and init join probe helerp cost {}ms, rows {}, pointer table size {}, " + "allocate pointer table and init join probe helper cost {}ms, rows {}, pointer table size {}, " "added column num {}, enable prefetch {}, enable tagged pointer {}, " "enable late materialization {}(avg size {})", watch.elapsedMilliseconds(), diff --git 
a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index fc5a4f010ed..18a6a761a3f 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -153,7 +153,7 @@ struct ProbeAdder(wd, added_columns); - helper.fillNullMapWithZero(added_columns); + helper.fillNullMapWithZero(added_columns); } }; @@ -198,7 +198,7 @@ struct ProbeAdder(wd, added_columns); - helper.fillNullMapWithZero(added_columns); + helper.fillNullMapWithZero(added_columns); if constexpr (!has_other_condition) { @@ -918,7 +918,7 @@ Block JoinProbeBlockHelper::handleOtherConditions( { size_t end = pos + step > start + length ? start + length : pos + step; wd.insert_batch.clear(); - wd.insert_batch.insert(&wd.row_ptrs_for_lm[pos], &wd.row_ptrs_for_lm[end]); + wd.insert_batch.insert(&wd.filter_row_ptrs_for_lm[pos], &wd.filter_row_ptrs_for_lm[end]); for (size_t i = other_column_indexes_start; i < other_column_indexes_size; ++i) { size_t column_index = row_layout.other_column_indexes[i].first; @@ -1013,7 +1013,7 @@ Block JoinProbeBlockHelper::handleOtherConditions( return res_block; } - if (kind == LeftOuter) + if (kind == LeftOuter && context.isProbeFinished()) return fillNotMatchedRowsForLeftOuter(context, wd); return output_block_after_finalize; @@ -1071,7 +1071,6 @@ Block JoinProbeBlockHelper::fillNotMatchedRowsForLeftOuter(JoinProbeContext & co } size_t remaining_insert_size = settings.max_block_size - wd.result_block_for_other_condition.rows(); - ; size_t result_size = context.not_matched_offsets.size() - context.not_matched_offsets_idx; size_t length = std::min(result_size, remaining_insert_size); diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h index dbdcb28e795..7ea16926341 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h @@ -269,21 +269,28 @@ class JoinProbeBlockHelper 
wd.insert_batch.clear(); } - template + template void ALWAYS_INLINE fillNullMapWithZero(MutableColumns & added_columns) const { if constexpr (has_null_key) { + size_t idx = 0; for (auto [column_index, is_nullable] : row_layout.raw_key_column_indexes) { if (is_nullable) { - auto & nullable_column = static_cast(*added_columns[column_index]); + size_t index; + if constexpr (late_materialization) + index = idx; + else + index = column_index; + auto & nullable_column = static_cast(*added_columns[index]); size_t data_size = nullable_column.getNestedColumn().size(); size_t nullmap_size = nullable_column.getNullMapColumn().size(); RUNTIME_CHECK(nullmap_size <= data_size); nullable_column.getNullMapColumn().getData().resize_fill_zero(data_size); } + ++idx; } } } From e4cbb155cdd030c4691a0c0055817704fec19a2a Mon Sep 17 00:00:00 2001 From: gengliqi Date: Tue, 25 Mar 2025 14:45:17 +0800 Subject: [PATCH 06/84] u Signed-off-by: gengliqi --- dbms/src/Columns/IColumn.h | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/dbms/src/Columns/IColumn.h b/dbms/src/Columns/IColumn.h index af3be2405d1..13450adadee 100644 --- a/dbms/src/Columns/IColumn.h +++ b/dbms/src/Columns/IColumn.h @@ -258,23 +258,23 @@ class IColumn : public COWPtr virtual void countSerializeByteSize(PaddedPODArray & /* byte_size */) const = 0; virtual void countSerializeByteSizeForCmp( PaddedPODArray & /* byte_size */, - const NullMap * /*nullmap*/, + const NullMap * /* nullmap */, const TiDB::TiDBCollatorPtr & /* collator */) const = 0; /// Count the serialize byte size and added to the byte_size called by ColumnArray. /// array_offsets is the offsets of ColumnArray. /// The byte_size.size() must be equal to the array_offsets.size(). 
+ virtual void countSerializeByteSizeForColumnArray( + PaddedPODArray & /* byte_size */, + const Offsets & /* array_offsets */) const + = 0; virtual void countSerializeByteSizeForCmpColumnArray( PaddedPODArray & /* byte_size */, const Offsets & /* array_offsets */, - const NullMap * /*nullmap*/, + const NullMap * /* nullmap */, const TiDB::TiDBCollatorPtr & /* collator */) const = 0; - virtual void countSerializeByteSizeForColumnArray( - PaddedPODArray & /* byte_size */, - const Offsets & /* array_offsets */) const - = 0; /// Serialize data of column from start to start + length into pointer of pos and forward each pos[i] to the end of /// serialized data. @@ -295,7 +295,7 @@ class IColumn : public COWPtr size_t /* start */, size_t /* length */, bool /* has_null */, - const NullMap * /*nullmap*/, + const NullMap * /* nullmap */, const TiDB::TiDBCollatorPtr & /* collator */, String * /* sort_key_container */) const = 0; @@ -316,7 +316,7 @@ class IColumn : public COWPtr size_t /* start */, size_t /* length */, bool /* has_null */, - const NullMap * /*nullmap*/, + const NullMap * /* nullmap */, const Offsets & /* array_offsets */, const TiDB::TiDBCollatorPtr & /* collator */, String * /* sort_key_container */) const From 147136ec0006e5b7230b54d9fe0bce442c823624 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Tue, 25 Mar 2025 15:46:37 +0800 Subject: [PATCH 07/84] save Signed-off-by: gengliqi --- dbms/src/Columns/ColumnAggregateFunction.h | 70 ++++++++++--------- dbms/src/Columns/ColumnArray.h | 56 ++++++++------- dbms/src/Columns/ColumnConst.h | 69 +++++++++--------- dbms/src/Columns/ColumnDecimal.cpp | 14 ++-- dbms/src/Columns/ColumnDecimal.h | 34 +++++---- dbms/src/Columns/ColumnFixedString.cpp | 10 ++- dbms/src/Columns/ColumnFixedString.h | 41 +++++------ dbms/src/Columns/ColumnFunction.h | 66 ++++++++--------- dbms/src/Columns/ColumnNullable.cpp | 16 ++--- dbms/src/Columns/ColumnNullable.h | 31 ++++---- dbms/src/Columns/ColumnString.h | 38 +++++----- 
dbms/src/Columns/ColumnTuple.h | 66 ++++++++--------- dbms/src/Columns/ColumnVector.h | 46 ++++++------ dbms/src/Columns/IColumn.h | 35 ++++++---- dbms/src/Columns/IColumnDummy.h | 69 +++++++++--------- .../Columns/tests/gtest_column_insertFrom.cpp | 8 +++ 16 files changed, 353 insertions(+), 316 deletions(-) diff --git a/dbms/src/Columns/ColumnAggregateFunction.h b/dbms/src/Columns/ColumnAggregateFunction.h index eb86da9d3f2..e0b30771f9c 100644 --- a/dbms/src/Columns/ColumnAggregateFunction.h +++ b/dbms/src/Columns/ColumnAggregateFunction.h @@ -131,10 +131,12 @@ class ColumnAggregateFunction final : public COWPtrHelper= start + length); + for (size_t i = start; i < start + length; ++i) + insertFrom(src_, selective_offsets[i]); } void insertFrom(ConstAggregateDataPtr __restrict place); @@ -165,6 +167,10 @@ class ColumnAggregateFunction final : public COWPtrHelper & /* byte_size */) const override + { + throw Exception("Method countSerializeByteSize is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); + } void countSerializeByteSizeForCmp( PaddedPODArray & /* byte_size */, const NullMap * /*nullmap*/, @@ -174,9 +180,14 @@ class ColumnAggregateFunction final : public COWPtrHelper & /* byte_size */) const override + + void countSerializeByteSizeForColumnArray( + PaddedPODArray & /* byte_size */, + const IColumn::Offsets & /* offsets */) const override { - throw Exception("Method countSerializeByteSize is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); + throw Exception( + "Method countSerializeByteSizeForColumnArray is not supported for " + getName(), + ErrorCodes::NOT_IMPLEMENTED); } void countSerializeByteSizeForCmpColumnArray( PaddedPODArray & /* byte_size */, @@ -188,15 +199,15 @@ class ColumnAggregateFunction final : public COWPtrHelper & /* byte_size */, - const IColumn::Offsets & /* offsets */) const override + + void serializeToPos( + PaddedPODArray & /* pos */, + size_t /* start */, + size_t /* length */, + bool /* has_null */) 
const override { - throw Exception( - "Method countSerializeByteSizeForColumnArray is not supported for " + getName(), - ErrorCodes::NOT_IMPLEMENTED); + throw Exception("Method serializeToPos is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void serializeToPosForCmp( PaddedPODArray & /* pos */, size_t /* start */, @@ -208,13 +219,17 @@ class ColumnAggregateFunction final : public COWPtrHelper & /* pos */, size_t /* start */, size_t /* length */, - bool /* has_null */) const override + bool /* has_null */, + const IColumn::Offsets & /* offsets */) const override { - throw Exception("Method serializeToPos is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); + throw Exception( + "Method serializeToPosForColumnArray is not supported for " + getName(), + ErrorCodes::NOT_IMPLEMENTED); } void serializeToPosForCmpColumnArray( PaddedPODArray & /* pos */, @@ -230,47 +245,36 @@ class ColumnAggregateFunction final : public COWPtrHelper & /* pos */, - size_t /* start */, - size_t /* length */, - bool /* has_null */, - const IColumn::Offsets & /* offsets */) const override + + void deserializeAndInsertFromPos(PaddedPODArray & /* pos */, bool /* use_nt_align_buffer */) override { throw Exception( - "Method serializeToPosForColumnArray is not supported for " + getName(), + "Method deserializeAndInsertFromPos is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void deserializeForCmpAndInsertFromPos(PaddedPODArray & /* pos */, bool /* use_nt_align_buffer */) override { throw Exception( "Method deserializeForCmpAndInsertFromPos is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void deserializeAndInsertFromPos(PaddedPODArray & /* pos */, bool /* use_nt_align_buffer */) override - { - throw Exception( - "Method deserializeAndInsertFromPos is not supported for " + getName(), - ErrorCodes::NOT_IMPLEMENTED); - } - void deserializeForCmpAndInsertFromPosColumnArray( + void deserializeAndInsertFromPosForColumnArray( 
PaddedPODArray & /* pos */, const IColumn::Offsets & /* array_offsets */, bool /* use_nt_align_buffer */) override { throw Exception( - "Method deserializeForCmpAndInsertFromPosColumnArray is not supported for " + getName(), + "Method deserializeAndInsertFromPosForColumnArray is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void deserializeAndInsertFromPosForColumnArray( + void deserializeForCmpAndInsertFromPosColumnArray( PaddedPODArray & /* pos */, const IColumn::Offsets & /* array_offsets */, bool /* use_nt_align_buffer */) override { throw Exception( - "Method deserializeAndInsertFromPosForColumnArray is not supported for " + getName(), + "Method deserializeForCmpAndInsertFromPosColumnArray is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } diff --git a/dbms/src/Columns/ColumnArray.h b/dbms/src/Columns/ColumnArray.h index e719f5a25d7..572523ce14a 100644 --- a/dbms/src/Columns/ColumnArray.h +++ b/dbms/src/Columns/ColumnArray.h @@ -100,31 +100,32 @@ class ColumnArray final : public COWPtrHelper String &) const override; const char * deserializeAndInsertFromArena(const char * pos, const TiDB::TiDBCollatorPtr &) override; + void countSerializeByteSize(PaddedPODArray & byte_size) const override; void countSerializeByteSizeForCmp( PaddedPODArray & byte_size, const NullMap * nullmap, const TiDB::TiDBCollatorPtr & collator) const override; - void countSerializeByteSize(PaddedPODArray & byte_size) const override; - void countSerializeByteSizeForCmpColumnArray( + void countSerializeByteSizeForColumnArray( PaddedPODArray & /* byte_size */, - const IColumn::Offsets & /* array_offsets */, - const NullMap * /*nullmap*/, - const TiDB::TiDBCollatorPtr & /* collator */) const override + const IColumn::Offsets & /* array_offsets */) const override { throw Exception( - "Method countSerializeByteSizeForCmpColumnArray is not supported for " + getName(), + "Method countSerializeByteSizeForColumnArray is not supported for " + getName(), 
ErrorCodes::NOT_IMPLEMENTED); } - void countSerializeByteSizeForColumnArray( + void countSerializeByteSizeForCmpColumnArray( PaddedPODArray & /* byte_size */, - const IColumn::Offsets & /* array_offsets */) const override + const IColumn::Offsets & /* array_offsets */, + const NullMap * /*nullmap*/, + const TiDB::TiDBCollatorPtr & /* collator */) const override { throw Exception( - "Method countSerializeByteSizeForColumnArray is not supported for " + getName(), + "Method countSerializeByteSizeForCmpColumnArray is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } + void serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const override; void serializeToPosForCmp( PaddedPODArray & pos, size_t start, @@ -133,53 +134,52 @@ class ColumnArray final : public COWPtrHelper const NullMap * nullmap, const TiDB::TiDBCollatorPtr & collator, String * sort_key_container) const override; - void serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const override; - void serializeToPosForCmpColumnArray( + void serializeToPosForColumnArray( PaddedPODArray & /* pos */, size_t /* start */, size_t /* length */, bool /* has_null */, - const NullMap * /* nullmap */, - const IColumn::Offsets & /* array_offsets */, - const TiDB::TiDBCollatorPtr & /* collator */, - String * /* sort_key_container */) const override + const IColumn::Offsets & /* array_offsets */) const override { throw Exception( - "Method serializeToPosForCmpColumnArray is not supported for " + getName(), + "Method serializeToPosForColumnArray is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void serializeToPosForColumnArray( + void serializeToPosForCmpColumnArray( PaddedPODArray & /* pos */, size_t /* start */, size_t /* length */, bool /* has_null */, - const IColumn::Offsets & /* array_offsets */) const override + const NullMap * /* nullmap */, + const IColumn::Offsets & /* array_offsets */, + const TiDB::TiDBCollatorPtr & /* 
collator */, + String * /* sort_key_container */) const override { throw Exception( - "Method serializeToPosForColumnArray is not supported for " + getName(), + "Method serializeToPosForCmpColumnArray is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void deserializeForCmpAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override; void deserializeAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override; + void deserializeForCmpAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override; - void deserializeForCmpAndInsertFromPosColumnArray( + void deserializeAndInsertFromPosForColumnArray( PaddedPODArray & /* pos */, const IColumn::Offsets & /* array_offsets */, bool /* use_nt_align_buffer */) override { throw Exception( - "Method deserializeForCmpAndInsertFromPosColumnArray is not supported for " + getName(), + "Method deserializeAndInsertFromPosForColumnArray is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void deserializeAndInsertFromPosForColumnArray( + void deserializeForCmpAndInsertFromPosColumnArray( PaddedPODArray & /* pos */, const IColumn::Offsets & /* array_offsets */, bool /* use_nt_align_buffer */) override { throw Exception( - "Method deserializeAndInsertFromPosForColumnArray is not supported for " + getName(), + "Method deserializeForCmpAndInsertFromPosColumnArray is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } @@ -200,10 +200,12 @@ class ColumnArray final : public COWPtrHelper insertFrom(src_, n); } - void insertSelectiveFrom(const IColumn & src_, const Offsets & selective_offsets) override + void insertSelectiveRangeFrom(const IColumn & src_, const Offsets & selective_offsets, size_t start, size_t length) + override { - for (auto position : selective_offsets) - insertFrom(src_, position); + RUNTIME_CHECK(selective_offsets.size() >= start + length); + for (size_t i = start; i < start + length; ++i) + insertFrom(src_, selective_offsets[i]); } 
void insertDefault() override; diff --git a/dbms/src/Columns/ColumnConst.h b/dbms/src/Columns/ColumnConst.h index 6f4ef69dc54..41d44c8ab89 100644 --- a/dbms/src/Columns/ColumnConst.h +++ b/dbms/src/Columns/ColumnConst.h @@ -80,10 +80,7 @@ class ColumnConst final : public COWPtrHelper void insertManyFrom(const IColumn &, size_t, size_t length) override { s += length; } - void insertSelectiveFrom(const IColumn &, const Offsets & selective_offsets) override - { - s += selective_offsets.size(); - } + void insertSelectiveRangeFrom(const IColumn &, const Offsets &, size_t, size_t length) override { s += length; } void insertMany(const Field &, size_t length) override { s += length; } @@ -112,6 +109,10 @@ class ColumnConst final : public COWPtrHelper return res; } + void countSerializeByteSize(PaddedPODArray & /* byte_size */) const override + { + throw Exception("Method countSerializeByteSize is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); + } void countSerializeByteSizeForCmp( PaddedPODArray & /* byte_size */, const NullMap * /*nullmap*/, @@ -121,11 +122,15 @@ class ColumnConst final : public COWPtrHelper "Method countSerializeByteSizeForCmp is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void countSerializeByteSize(PaddedPODArray & /* byte_size */) const override + + void countSerializeByteSizeForColumnArray( + PaddedPODArray & /* byte_size */, + const IColumn::Offsets & /* array_offsets */) const override { - throw Exception("Method countSerializeByteSize is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); + throw Exception( + "Method countSerializeByteSizeForColumnArray is not supported for " + getName(), + ErrorCodes::NOT_IMPLEMENTED); } - void countSerializeByteSizeForCmpColumnArray( PaddedPODArray & /* byte_size */, const IColumn::Offsets & /* array_offsets */, @@ -136,15 +141,15 @@ class ColumnConst final : public COWPtrHelper "Method countSerializeByteSizeForCmpColumnArray is not supported for " + getName(), 
ErrorCodes::NOT_IMPLEMENTED); } - void countSerializeByteSizeForColumnArray( - PaddedPODArray & /* byte_size */, - const IColumn::Offsets & /* array_offsets */) const override + + void serializeToPos( + PaddedPODArray & /* pos */, + size_t /* start */, + size_t /* length */, + bool /* has_null */) const override { - throw Exception( - "Method countSerializeByteSizeForColumnArray is not supported for " + getName(), - ErrorCodes::NOT_IMPLEMENTED); + throw Exception("Method serializeToPos is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void serializeToPosForCmp( PaddedPODArray & /* pos */, size_t /* start */, @@ -156,15 +161,18 @@ class ColumnConst final : public COWPtrHelper { throw Exception("Method serializeToPosForCmp is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void serializeToPos( + + void serializeToPosForColumnArray( PaddedPODArray & /* pos */, size_t /* start */, size_t /* length */, - bool /* has_null */) const override + bool /* has_null */, + const IColumn::Offsets & /* array_offsets */) const override { - throw Exception("Method serializeToPos is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); + throw Exception( + "Method serializeToPosForColumnArray is not supported for " + getName(), + ErrorCodes::NOT_IMPLEMENTED); } - void serializeToPosForCmpColumnArray( PaddedPODArray & /* pos */, size_t /* start */, @@ -179,47 +187,36 @@ class ColumnConst final : public COWPtrHelper "Method serializeToPosForCmpColumnArray is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void serializeToPosForColumnArray( - PaddedPODArray & /* pos */, - size_t /* start */, - size_t /* length */, - bool /* has_null */, - const IColumn::Offsets & /* array_offsets */) const override + + void deserializeAndInsertFromPos(PaddedPODArray & /* pos */, bool /* use_nt_align_buffer */) override { throw Exception( - "Method serializeToPosForColumnArray is not supported for " + getName(), + "Method 
deserializeAndInsertFromPos is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void deserializeForCmpAndInsertFromPos(PaddedPODArray & /* pos */, bool /* use_nt_align_buffer */) override { throw Exception( "Method deserializeForCmpAndInsertFromPos is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void deserializeAndInsertFromPos(PaddedPODArray & /* pos */, bool /* use_nt_align_buffer */) override - { - throw Exception( - "Method deserializeAndInsertFromPos is not supported for " + getName(), - ErrorCodes::NOT_IMPLEMENTED); - } - void deserializeForCmpAndInsertFromPosColumnArray( + void deserializeAndInsertFromPosForColumnArray( PaddedPODArray & /* pos */, const IColumn::Offsets & /* array_offsets */, bool /* use_nt_align_buffer */) override { throw Exception( - "Method deserializeForCmpAndInsertFromPosColumnArray is not supported for " + getName(), + "Method deserializeAndInsertFromPosForColumnArray is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void deserializeAndInsertFromPosForColumnArray( + void deserializeForCmpAndInsertFromPosColumnArray( PaddedPODArray & /* pos */, const IColumn::Offsets & /* array_offsets */, bool /* use_nt_align_buffer */) override { throw Exception( - "Method deserializeAndInsertFromPosForColumnArray is not supported for " + getName(), + "Method deserializeForCmpAndInsertFromPosColumnArray is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } diff --git a/dbms/src/Columns/ColumnDecimal.cpp b/dbms/src/Columns/ColumnDecimal.cpp index 346cc665579..72aded07965 100644 --- a/dbms/src/Columns/ColumnDecimal.cpp +++ b/dbms/src/Columns/ColumnDecimal.cpp @@ -922,14 +922,18 @@ void ColumnDecimal::insertManyFrom(const IColumn & src, size_t position, size } template -void ColumnDecimal::insertSelectiveFrom(const IColumn & src, const IColumn::Offsets & selective_offsets) +void ColumnDecimal::insertSelectiveRangeFrom( + const IColumn & src, + const IColumn::Offsets & 
selective_offsets, + size_t start, + size_t length) { + RUNTIME_CHECK(selective_offsets.size() >= start + length); const auto & src_data = static_cast(src).data; size_t old_size = data.size(); - size_t to_add_size = selective_offsets.size(); - data.resize(old_size + to_add_size); - for (size_t i = 0; i < to_add_size; ++i) - data[i + old_size] = src_data[selective_offsets[i]]; + data.resize(old_size + length); + for (size_t i = 0; i < length; ++i) + data[i + old_size] = src_data[selective_offsets[i + start]]; } #pragma GCC diagnostic pop diff --git a/dbms/src/Columns/ColumnDecimal.h b/dbms/src/Columns/ColumnDecimal.h index 354519fa379..81146cff245 100644 --- a/dbms/src/Columns/ColumnDecimal.h +++ b/dbms/src/Columns/ColumnDecimal.h @@ -155,7 +155,11 @@ class ColumnDecimal final : public COWPtrHelper::Type>(x)); } void insertRangeFrom(const IColumn & src, size_t start, size_t length) override; void insertManyFrom(const IColumn & src_, size_t position, size_t length) override; - void insertSelectiveFrom(const IColumn & src_, const IColumn::Offsets & selective_offsets) override; + void insertSelectiveRangeFrom( + const IColumn & src_, + const IColumn::Offsets & selective_offsets, + size_t start, + size_t length) override; void popBack(size_t n) override { data.resize_assume_reserved(data.size() - n); } StringRef getRawData() const override @@ -175,21 +179,22 @@ class ColumnDecimal final : public COWPtrHelper & byte_size) const override; void countSerializeByteSizeForCmp( PaddedPODArray & byte_size, const NullMap * nullmap, const TiDB::TiDBCollatorPtr &) const override; - void countSerializeByteSize(PaddedPODArray & byte_size) const override; + void countSerializeByteSizeForColumnArray( + PaddedPODArray & byte_size, + const IColumn::Offsets & array_offsets) const override; void countSerializeByteSizeForCmpColumnArray( PaddedPODArray & byte_size, const IColumn::Offsets & array_offsets, const NullMap * nullmap, const TiDB::TiDBCollatorPtr &) const override; - void 
countSerializeByteSizeForColumnArray( - PaddedPODArray & byte_size, - const IColumn::Offsets & array_offsets) const override; + void serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const override; void serializeToPosForCmp( PaddedPODArray & pos, size_t start, @@ -198,8 +203,13 @@ class ColumnDecimal final : public COWPtrHelper & pos, size_t start, size_t length, bool has_null) const override; + void serializeToPosForColumnArray( + PaddedPODArray & pos, + size_t start, + size_t length, + bool has_null, + const IColumn::Offsets & array_offsets) const override; void serializeToPosForCmpColumnArray( PaddedPODArray & pos, size_t start, @@ -209,21 +219,15 @@ class ColumnDecimal final : public COWPtrHelper & pos, - size_t start, - size_t length, - bool has_null, - const IColumn::Offsets & array_offsets) const override; - void deserializeForCmpAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override; void deserializeAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override; + void deserializeForCmpAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override; - void deserializeForCmpAndInsertFromPosColumnArray( + void deserializeAndInsertFromPosForColumnArray( PaddedPODArray & pos, const IColumn::Offsets & array_offsets, bool use_nt_align_buffer) override; - void deserializeAndInsertFromPosForColumnArray( + void deserializeForCmpAndInsertFromPosColumnArray( PaddedPODArray & pos, const IColumn::Offsets & array_offsets, bool use_nt_align_buffer) override; diff --git a/dbms/src/Columns/ColumnFixedString.cpp b/dbms/src/Columns/ColumnFixedString.cpp index d02de8dfd0c..760e5aa6d26 100644 --- a/dbms/src/Columns/ColumnFixedString.cpp +++ b/dbms/src/Columns/ColumnFixedString.cpp @@ -89,16 +89,20 @@ void ColumnFixedString::insertManyFrom(const IColumn & src_, size_t position, si memcpySmallAllowReadWriteOverflow15(&chars[i], src_char_ptr, n); } -void ColumnFixedString::insertSelectiveFrom(const IColumn & 
src_, const Offsets & selective_offsets) +void ColumnFixedString::insertSelectiveRangeFrom( + const IColumn & src_, + const Offsets & selective_offsets, + size_t start, + size_t length) { const auto & src = static_cast(src_); if (n != src.getN()) throw Exception("Size of FixedString doesn't match", ErrorCodes::SIZE_OF_FIXED_STRING_DOESNT_MATCH); size_t old_size = chars.size(); - size_t new_size = old_size + selective_offsets.size() * n; + size_t new_size = old_size + length * n; chars.resize(new_size); const auto & src_chars = src.chars; - for (size_t i = old_size, j = 0; i < new_size; i += n, ++j) + for (size_t i = old_size, j = start; i < new_size; i += n, ++j) memcpySmallAllowReadWriteOverflow15(&chars[i], &src_chars[selective_offsets[j] * n], n); } diff --git a/dbms/src/Columns/ColumnFixedString.h b/dbms/src/Columns/ColumnFixedString.h index 7b6dd4d42b6..a437c8c617b 100644 --- a/dbms/src/Columns/ColumnFixedString.h +++ b/dbms/src/Columns/ColumnFixedString.h @@ -105,7 +105,8 @@ class ColumnFixedString final : public COWPtrHelper void insertManyFrom(const IColumn & src_, size_t position, size_t length) override; - void insertSelectiveFrom(const IColumn & src_, const Offsets & selective_offsets) override; + void insertSelectiveRangeFrom(const IColumn & src_, const Offsets & selective_offsets, size_t start, size_t length) + override; void insertData(const char * pos, size_t length) override; @@ -123,6 +124,10 @@ class ColumnFixedString final : public COWPtrHelper const char * deserializeAndInsertFromArena(const char * pos, const TiDB::TiDBCollatorPtr &) override; + void countSerializeByteSize(PaddedPODArray & byte_size) const override + { + countSerializeByteSizeImpl(byte_size); + } void countSerializeByteSizeForCmp( PaddedPODArray & byte_size, const NullMap * /*nullmap*/, @@ -135,20 +140,17 @@ class ColumnFixedString final : public COWPtrHelper getName()); countSerializeByteSizeImpl(byte_size); } - void countSerializeByteSize(PaddedPODArray & byte_size) const 
override - { - countSerializeByteSizeImpl(byte_size); - } + void countSerializeByteSizeForColumnArray( + PaddedPODArray & byte_size, + const IColumn::Offsets & array_offsets) const override; void countSerializeByteSizeForCmpColumnArray( PaddedPODArray & byte_size, const IColumn::Offsets & array_offsets, const NullMap * nullmap, const TiDB::TiDBCollatorPtr & collator) const override; - void countSerializeByteSizeForColumnArray( - PaddedPODArray & byte_size, - const IColumn::Offsets & array_offsets) const override; + void serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const override; void serializeToPosForCmp( PaddedPODArray & pos, size_t start, @@ -157,8 +159,13 @@ class ColumnFixedString final : public COWPtrHelper const NullMap * nullmap, const TiDB::TiDBCollatorPtr & collator, String *) const override; - void serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const override; + void serializeToPosForColumnArray( + PaddedPODArray & pos, + size_t start, + size_t length, + bool has_null, + const IColumn::Offsets & array_offsets) const override; void serializeToPosForCmpColumnArray( PaddedPODArray & pos, size_t start, @@ -168,19 +175,17 @@ class ColumnFixedString final : public COWPtrHelper const IColumn::Offsets & array_offsets, const TiDB::TiDBCollatorPtr & collator, String *) const override; - void serializeToPosForColumnArray( - PaddedPODArray & pos, - size_t start, - size_t length, - bool has_null, - const IColumn::Offsets & array_offsets) const override; + void deserializeAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override; void deserializeForCmpAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override { deserializeAndInsertFromPos(pos, use_nt_align_buffer); } - void deserializeAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override; + void deserializeAndInsertFromPosForColumnArray( + PaddedPODArray & pos, + const IColumn::Offsets & 
array_offsets, + bool use_nt_align_buffer) override; void deserializeForCmpAndInsertFromPosColumnArray( PaddedPODArray & pos, const IColumn::Offsets & array_offsets, @@ -188,10 +193,6 @@ class ColumnFixedString final : public COWPtrHelper { deserializeAndInsertFromPosForColumnArray(pos, array_offsets, use_nt_align_buffer); } - void deserializeAndInsertFromPosForColumnArray( - PaddedPODArray & pos, - const IColumn::Offsets & array_offsets, - bool use_nt_align_buffer) override; void flushNTAlignBuffer() override {} diff --git a/dbms/src/Columns/ColumnFunction.h b/dbms/src/Columns/ColumnFunction.h index 7813d2a0c13..ef712883144 100644 --- a/dbms/src/Columns/ColumnFunction.h +++ b/dbms/src/Columns/ColumnFunction.h @@ -99,7 +99,7 @@ class ColumnFunction final : public COWPtrHelper throw Exception("Cannot insert into " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void insertSelectiveFrom(const IColumn &, const Offsets &) override + void insertSelectiveRangeFrom(const IColumn &, const Offsets &, size_t, size_t) override { throw Exception("Cannot insert into " + getName(), ErrorCodes::NOT_IMPLEMENTED); } @@ -120,6 +120,10 @@ class ColumnFunction final : public COWPtrHelper throw Exception("Cannot deserialize to " + getName(), ErrorCodes::NOT_IMPLEMENTED); } + void countSerializeByteSize(PaddedPODArray & /* byte_size */) const override + { + throw Exception("Method countSerializeByteSize is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); + } void countSerializeByteSizeForCmp( PaddedPODArray & /* byte_size */, const NullMap * /*nullmap*/, @@ -129,11 +133,15 @@ class ColumnFunction final : public COWPtrHelper "Method countSerializeByteSizeForCmp is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void countSerializeByteSize(PaddedPODArray & /* byte_size */) const override + + void countSerializeByteSizeForColumnArray( + PaddedPODArray & /* byte_size */, + const IColumn::Offsets & /* offsets */) const override { - throw Exception("Method 
countSerializeByteSize is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); + throw Exception( + "Method countSerializeByteSizeForColumnArray is not supported for " + getName(), + ErrorCodes::NOT_IMPLEMENTED); } - void countSerializeByteSizeForCmpColumnArray( PaddedPODArray & /* byte_size */, const IColumn::Offsets & /* offsets */, @@ -144,15 +152,15 @@ class ColumnFunction final : public COWPtrHelper "Method countSerializeByteSizeForCmpColumnArray is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void countSerializeByteSizeForColumnArray( - PaddedPODArray & /* byte_size */, - const IColumn::Offsets & /* offsets */) const override + + void serializeToPos( + PaddedPODArray & /* pos */, + size_t /* start */, + size_t /* length */, + bool /* has_null */) const override { - throw Exception( - "Method countSerializeByteSizeForColumnArray is not supported for " + getName(), - ErrorCodes::NOT_IMPLEMENTED); + throw Exception("Method serializeToPos is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void serializeToPosForCmp( PaddedPODArray & /* pos */, size_t /* start */, @@ -164,15 +172,18 @@ class ColumnFunction final : public COWPtrHelper { throw Exception("Method serializeToPosForCmp is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void serializeToPos( + + void serializeToPosForColumnArray( PaddedPODArray & /* pos */, size_t /* start */, size_t /* length */, - bool /* has_null */) const override + bool /* has_null */, + const IColumn::Offsets & /* array_offsets */) const override { - throw Exception("Method serializeToPos is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); + throw Exception( + "Method serializeToPosForColumnArray is not supported for " + getName(), + ErrorCodes::NOT_IMPLEMENTED); } - void serializeToPosForCmpColumnArray( PaddedPODArray & /* pos */, size_t /* start */, @@ -187,47 +198,36 @@ class ColumnFunction final : public COWPtrHelper "Method 
serializeToPosForCmpColumnArray is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void serializeToPosForColumnArray( - PaddedPODArray & /* pos */, - size_t /* start */, - size_t /* length */, - bool /* has_null */, - const IColumn::Offsets & /* array_offsets */) const override + + void deserializeAndInsertFromPos(PaddedPODArray & /* pos */, bool /* use_nt_align_buffer */) override { throw Exception( - "Method serializeToPosForColumnArray is not supported for " + getName(), + "Method deserializeAndInsertFromPos is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void deserializeForCmpAndInsertFromPos(PaddedPODArray & /* pos */, bool /* use_nt_align_buffer */) override { throw Exception( "Method deserializeForCmpAndInsertFromPos is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void deserializeAndInsertFromPos(PaddedPODArray & /* pos */, bool /* use_nt_align_buffer */) override - { - throw Exception( - "Method deserializeAndInsertFromPos is not supported for " + getName(), - ErrorCodes::NOT_IMPLEMENTED); - } - void deserializeForCmpAndInsertFromPosColumnArray( + void deserializeAndInsertFromPosForColumnArray( PaddedPODArray & /* pos */, const IColumn::Offsets & /* array_offsets */, bool /* use_nt_align_buffer */) override { throw Exception( - "Method deserializeForCmpAndInsertFromPosColumnArray is not supported for " + getName(), + "Method deserializeAndInsertFromPosForColumnArray is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void deserializeAndInsertFromPosForColumnArray( + void deserializeForCmpAndInsertFromPosColumnArray( PaddedPODArray & /* pos */, const IColumn::Offsets & /* array_offsets */, bool /* use_nt_align_buffer */) override { throw Exception( - "Method deserializeAndInsertFromPosForColumnArray is not supported for " + getName(), + "Method deserializeForCmpAndInsertFromPosColumnArray is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } diff --git 
a/dbms/src/Columns/ColumnNullable.cpp b/dbms/src/Columns/ColumnNullable.cpp index c407877f465..38a23280f13 100644 --- a/dbms/src/Columns/ColumnNullable.cpp +++ b/dbms/src/Columns/ColumnNullable.cpp @@ -437,17 +437,15 @@ void ColumnNullable::insertManyFrom(const IColumn & src, size_t n, size_t length map.resize_fill(map.size() + length, src_concrete.getNullMapData()[n]); } -void ColumnNullable::insertSelectiveFrom(const IColumn & src, const Offsets & selective_offsets) +void ColumnNullable::insertSelectiveRangeFrom( + const IColumn & src, + const Offsets & selective_offsets, + size_t start, + size_t length) { const auto & src_concrete = static_cast(src); - getNestedColumn().insertSelectiveFrom(src_concrete.getNestedColumn(), selective_offsets); - auto & map = getNullMapData(); - const auto & src_map = src_concrete.getNullMapData(); - size_t old_size = map.size(); - size_t to_add_size = selective_offsets.size(); - map.resize(old_size + to_add_size); - for (size_t i = 0; i < to_add_size; ++i) - map[i + old_size] = src_map[selective_offsets[i]]; + getNestedColumn().insertSelectiveRangeFrom(src_concrete.getNestedColumn(), selective_offsets, start, length); + getNullMapColumn().insertSelectiveRangeFrom(src_concrete.getNullMapColumn(), selective_offsets, start, length); } void ColumnNullable::popBack(size_t n) diff --git a/dbms/src/Columns/ColumnNullable.h b/dbms/src/Columns/ColumnNullable.h index dd074ac8cb1..d173221f0f1 100644 --- a/dbms/src/Columns/ColumnNullable.h +++ b/dbms/src/Columns/ColumnNullable.h @@ -80,21 +80,22 @@ class ColumnNullable final : public COWPtrHelper String &) const override; const char * deserializeAndInsertFromArena(const char * pos, const TiDB::TiDBCollatorPtr &) override; + void countSerializeByteSize(PaddedPODArray & byte_size) const override; void countSerializeByteSizeForCmp( PaddedPODArray & byte_size, const NullMap * nullmap, const TiDB::TiDBCollatorPtr & collator) const override; - void countSerializeByteSize(PaddedPODArray & byte_size) 
const override; + void countSerializeByteSizeForColumnArray( + PaddedPODArray & byte_size, + const IColumn::Offsets & array_offsets) const override; void countSerializeByteSizeForCmpColumnArray( PaddedPODArray & byte_size, const IColumn::Offsets & array_offsets, const NullMap * nullmap, const TiDB::TiDBCollatorPtr & collator) const override; - void countSerializeByteSizeForColumnArray( - PaddedPODArray & byte_size, - const IColumn::Offsets & array_offsets) const override; + void serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const override; void serializeToPosForCmp( PaddedPODArray & pos, size_t start, @@ -103,8 +104,13 @@ class ColumnNullable final : public COWPtrHelper const NullMap * nullmap, const TiDB::TiDBCollatorPtr & collator, String * sort_key_container) const override; - void serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const override; + void serializeToPosForColumnArray( + PaddedPODArray & pos, + size_t start, + size_t length, + bool has_null, + const IColumn::Offsets & array_offsets) const override; void serializeToPosForCmpColumnArray( PaddedPODArray & pos, size_t start, @@ -114,21 +120,15 @@ class ColumnNullable final : public COWPtrHelper const IColumn::Offsets & array_offsets, const TiDB::TiDBCollatorPtr & collator, String * sort_key_container) const override; - void serializeToPosForColumnArray( - PaddedPODArray & pos, - size_t start, - size_t length, - bool has_null, - const IColumn::Offsets & array_offsets) const override; - void deserializeForCmpAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override; void deserializeAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override; + void deserializeForCmpAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override; - void deserializeForCmpAndInsertFromPosColumnArray( + void deserializeAndInsertFromPosForColumnArray( PaddedPODArray & pos, const IColumn::Offsets & array_offsets, bool 
use_nt_align_buffer) override; - void deserializeAndInsertFromPosForColumnArray( + void deserializeForCmpAndInsertFromPosColumnArray( PaddedPODArray & pos, const IColumn::Offsets & array_offsets, bool use_nt_align_buffer) override; @@ -140,7 +140,8 @@ class ColumnNullable final : public COWPtrHelper void insert(const Field & x) override; void insertFrom(const IColumn & src, size_t n) override; void insertManyFrom(const IColumn & src, size_t n, size_t length) override; - void insertSelectiveFrom(const IColumn & src, const Offsets & selective_offsets) override; + void insertSelectiveRangeFrom(const IColumn & src, const Offsets & selective_offsets, size_t start, size_t length) + override; void insertDefault() override { diff --git a/dbms/src/Columns/ColumnString.h b/dbms/src/Columns/ColumnString.h index f6731c2f1d8..55c160236a8 100644 --- a/dbms/src/Columns/ColumnString.h +++ b/dbms/src/Columns/ColumnString.h @@ -228,12 +228,14 @@ class ColumnString final : public COWPtrHelper insertFromImpl(src, position); } - void insertSelectiveFrom(const IColumn & src_, const Offsets & selective_offsets) override + void insertSelectiveRangeFrom(const IColumn & src_, const Offsets & selective_offsets, size_t start, size_t length) + override { + RUNTIME_CHECK(selective_offsets.size() >= start + length); const auto & src = static_cast(src_); - offsets.reserve(offsets.size() + selective_offsets.size()); - for (auto position : selective_offsets) - insertFromImpl(src, position); + offsets.reserve(offsets.size() + length); + for (size_t i = start; i < start + length; ++i) + insertFromImpl(src, selective_offsets[i]); } template @@ -310,21 +312,22 @@ class ColumnString final : public COWPtrHelper return pos + string_size; } + void countSerializeByteSize(PaddedPODArray & byte_size) const override; void countSerializeByteSizeForCmp( PaddedPODArray & byte_size, const NullMap * nullmap, const TiDB::TiDBCollatorPtr & collator) const override; - void countSerializeByteSize(PaddedPODArray & 
byte_size) const override; + void countSerializeByteSizeForColumnArray( + PaddedPODArray & byte_size, + const IColumn::Offsets & array_offsets) const override; void countSerializeByteSizeForCmpColumnArray( PaddedPODArray & byte_size, const IColumn::Offsets & array_offsets, const NullMap * nullmap, const TiDB::TiDBCollatorPtr & collator) const override; - void countSerializeByteSizeForColumnArray( - PaddedPODArray & byte_size, - const IColumn::Offsets & array_offsets) const override; + void serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const override; void serializeToPosForCmp( PaddedPODArray & pos, size_t start, @@ -333,8 +336,13 @@ class ColumnString final : public COWPtrHelper const NullMap * nullmap, const TiDB::TiDBCollatorPtr & collator, String * sort_key_container) const override; - void serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const override; + void serializeToPosForColumnArray( + PaddedPODArray & pos, + size_t start, + size_t length, + bool has_null, + const IColumn::Offsets & array_offsets) const override; void serializeToPosForCmpColumnArray( PaddedPODArray & pos, size_t start, @@ -344,21 +352,15 @@ class ColumnString final : public COWPtrHelper const IColumn::Offsets & array_offsets, const TiDB::TiDBCollatorPtr & collator, String * sort_key_container) const override; - void serializeToPosForColumnArray( - PaddedPODArray & pos, - size_t start, - size_t length, - bool has_null, - const IColumn::Offsets & array_offsets) const override; - void deserializeForCmpAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override; void deserializeAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override; + void deserializeForCmpAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override; - void deserializeForCmpAndInsertFromPosColumnArray( + void deserializeAndInsertFromPosForColumnArray( PaddedPODArray & pos, const IColumn::Offsets & array_offsets, 
bool use_nt_align_buffer) override; - void deserializeAndInsertFromPosForColumnArray( + void deserializeForCmpAndInsertFromPosColumnArray( PaddedPODArray & pos, const IColumn::Offsets & array_offsets, bool use_nt_align_buffer) override; diff --git a/dbms/src/Columns/ColumnTuple.h b/dbms/src/Columns/ColumnTuple.h index a1a37c6a1cf..dac317a4f0d 100644 --- a/dbms/src/Columns/ColumnTuple.h +++ b/dbms/src/Columns/ColumnTuple.h @@ -72,10 +72,12 @@ class ColumnTuple final : public COWPtrHelper insertFrom(src_, n); } - void insertSelectiveFrom(const IColumn & src_, const Offsets & selective_offsets) override + void insertSelectiveRangeFrom(const IColumn & src_, const Offsets & selective_offsets, size_t start, size_t length) + override { - for (auto position : selective_offsets) - insertFrom(src_, position); + RUNTIME_CHECK(selective_offsets.size() >= start + length); + for (size_t i = start; i < start + length; ++i) + insertFrom(src_, selective_offsets[i]); } void insertDefault() override; @@ -95,6 +97,11 @@ class ColumnTuple final : public COWPtrHelper String &) const override; const char * deserializeAndInsertFromArena(const char * pos, const TiDB::TiDBCollatorPtr &) override; + void countSerializeByteSize(PaddedPODArray & byte_size) const override + { + for (const auto & column : columns) + column->countSerializeByteSize(byte_size); + } void countSerializeByteSizeForCmp( PaddedPODArray & byte_size, const NullMap * nullmap, @@ -103,12 +110,14 @@ class ColumnTuple final : public COWPtrHelper for (const auto & column : columns) column->countSerializeByteSizeForCmp(byte_size, nullmap, collator); } - void countSerializeByteSize(PaddedPODArray & byte_size) const override + + void countSerializeByteSizeForColumnArray( + PaddedPODArray & byte_size, + const IColumn::Offsets & array_offsets) const override { for (const auto & column : columns) - column->countSerializeByteSize(byte_size); + column->countSerializeByteSizeForColumnArray(byte_size, array_offsets); } - void 
countSerializeByteSizeForCmpColumnArray( PaddedPODArray & byte_size, const IColumn::Offsets & array_offsets, @@ -118,14 +127,12 @@ class ColumnTuple final : public COWPtrHelper for (const auto & column : columns) column->countSerializeByteSizeForCmpColumnArray(byte_size, array_offsets, nullmap, collator); } - void countSerializeByteSizeForColumnArray( - PaddedPODArray & byte_size, - const IColumn::Offsets & array_offsets) const override + + void serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const override { for (const auto & column : columns) - column->countSerializeByteSizeForColumnArray(byte_size, array_offsets); + column->serializeToPos(pos, start, length, has_null); } - void serializeToPosForCmp( PaddedPODArray & pos, size_t start, @@ -138,12 +145,17 @@ class ColumnTuple final : public COWPtrHelper for (const auto & column : columns) column->serializeToPosForCmp(pos, start, length, has_null, nullmap, collator, sort_key_container); } - void serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const override + + void serializeToPosForColumnArray( + PaddedPODArray & pos, + size_t start, + size_t length, + bool has_null, + const IColumn::Offsets & array_offsets) const override { for (const auto & column : columns) - column->serializeToPos(pos, start, length, has_null); + column->serializeToPosForColumnArray(pos, start, length, has_null, array_offsets); } - void serializeToPosForCmpColumnArray( PaddedPODArray & pos, size_t start, @@ -165,46 +177,36 @@ class ColumnTuple final : public COWPtrHelper collator, sort_key_container); } - void serializeToPosForColumnArray( - PaddedPODArray & pos, - size_t start, - size_t length, - bool has_null, - const IColumn::Offsets & array_offsets) const override - { - for (const auto & column : columns) - column->serializeToPosForColumnArray(pos, start, length, has_null, array_offsets); - } - void deserializeForCmpAndInsertFromPos(PaddedPODArray & pos, bool 
use_nt_align_buffer) override + void deserializeAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override { for (auto & column : columns) - column->assumeMutableRef().deserializeForCmpAndInsertFromPos(pos, use_nt_align_buffer); + column->assumeMutableRef().deserializeAndInsertFromPos(pos, use_nt_align_buffer); } - void deserializeAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override + void deserializeForCmpAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override { for (auto & column : columns) - column->assumeMutableRef().deserializeAndInsertFromPos(pos, use_nt_align_buffer); + column->assumeMutableRef().deserializeForCmpAndInsertFromPos(pos, use_nt_align_buffer); } - void deserializeForCmpAndInsertFromPosColumnArray( + void deserializeAndInsertFromPosForColumnArray( PaddedPODArray & pos, const IColumn::Offsets & array_offsets, bool use_nt_align_buffer) override { for (auto & column : columns) - column->assumeMutableRef().deserializeForCmpAndInsertFromPosColumnArray( + column->assumeMutableRef().deserializeAndInsertFromPosForColumnArray( pos, array_offsets, use_nt_align_buffer); } - void deserializeAndInsertFromPosForColumnArray( + void deserializeForCmpAndInsertFromPosColumnArray( PaddedPODArray & pos, const IColumn::Offsets & array_offsets, bool use_nt_align_buffer) override { for (auto & column : columns) - column->assumeMutableRef().deserializeAndInsertFromPosForColumnArray( + column->assumeMutableRef().deserializeForCmpAndInsertFromPosColumnArray( pos, array_offsets, use_nt_align_buffer); diff --git a/dbms/src/Columns/ColumnVector.h b/dbms/src/Columns/ColumnVector.h index a78f0539517..17f76f2e135 100644 --- a/dbms/src/Columns/ColumnVector.h +++ b/dbms/src/Columns/ColumnVector.h @@ -236,14 +236,18 @@ class ColumnVector final : public COWPtrHelper= start + length); const auto & src_container = static_cast(src).getData(); size_t old_size = data.size(); - size_t to_add_size = selective_offsets.size(); - 
data.resize(old_size + to_add_size); - for (size_t i = 0; i < to_add_size; ++i) - data[i + old_size] = src_container[selective_offsets[i]]; + data.resize(old_size + length); + for (size_t i = 0; i < length; ++i) + data[i + old_size] = src_container[selective_offsets[start + i]]; } void insertMany(const Field & field, size_t length) override @@ -334,6 +338,7 @@ class ColumnVector final : public COWPtrHelper & byte_size) const override; void countSerializeByteSizeForCmp( PaddedPODArray & byte_size, const NullMap * /*nullmap*/, @@ -341,17 +346,17 @@ class ColumnVector final : public COWPtrHelper & byte_size) const override; + void countSerializeByteSizeForColumnArray( + PaddedPODArray & byte_size, + const IColumn::Offsets & array_offsets) const override; void countSerializeByteSizeForCmpColumnArray( PaddedPODArray & byte_size, const IColumn::Offsets & array_offsets, const NullMap * nullmap, const TiDB::TiDBCollatorPtr &) const override; - void countSerializeByteSizeForColumnArray( - PaddedPODArray & byte_size, - const IColumn::Offsets & array_offsets) const override; + void serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const override; void serializeToPosForCmp( PaddedPODArray & pos, size_t start, @@ -360,8 +365,13 @@ class ColumnVector final : public COWPtrHelper & pos, size_t start, size_t length, bool has_null) const override; + void serializeToPosForColumnArray( + PaddedPODArray & pos, + size_t start, + size_t length, + bool has_null, + const IColumn::Offsets & array_offsets) const override; void serializeToPosForCmpColumnArray( PaddedPODArray & pos, size_t start, @@ -371,19 +381,17 @@ class ColumnVector final : public COWPtrHelper & pos, - size_t start, - size_t length, - bool has_null, - const IColumn::Offsets & array_offsets) const override; + void deserializeAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override; void deserializeForCmpAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override 
{ deserializeAndInsertFromPos(pos, use_nt_align_buffer); } - void deserializeAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override; + void deserializeAndInsertFromPosForColumnArray( + PaddedPODArray & pos, + const IColumn::Offsets & array_offsets, + bool use_nt_align_buffer) override; void deserializeForCmpAndInsertFromPosColumnArray( PaddedPODArray & pos, const IColumn::Offsets & array_offsets, @@ -391,10 +399,6 @@ class ColumnVector final : public COWPtrHelper & pos, - const IColumn::Offsets & array_offsets, - bool use_nt_align_buffer) override; void flushNTAlignBuffer() override; diff --git a/dbms/src/Columns/IColumn.h b/dbms/src/Columns/IColumn.h index ec9bf7d1811..13450adadee 100644 --- a/dbms/src/Columns/IColumn.h +++ b/dbms/src/Columns/IColumn.h @@ -156,7 +156,16 @@ class IColumn : public COWPtr /// Note: the source column and the destination column must be of the same type, can not ColumnXXX->insertSelectiveFrom(ConstColumnXXX, ...) using Offset = UInt64; using Offsets = PaddedPODArray; - virtual void insertSelectiveFrom(const IColumn & src, const Offsets & selective_offsets) = 0; + void insertSelectiveFrom(const IColumn & src, const Offsets & selective_offsets) + { + insertSelectiveRangeFrom(src, selective_offsets, 0, selective_offsets.size()); + } + virtual void insertSelectiveRangeFrom( + const IColumn & src, + const Offsets & selective_offsets, + size_t start, + size_t length) + = 0; /// Appends one field multiple times. Can be optimized in inherited classes. virtual void insertMany(const Field & field, size_t length) @@ -246,26 +255,26 @@ class IColumn : public COWPtr /// Count the serialize byte size and added to the byte_size. /// The byte_size.size() must be equal to the column size. 
+ virtual void countSerializeByteSize(PaddedPODArray & /* byte_size */) const = 0; virtual void countSerializeByteSizeForCmp( PaddedPODArray & /* byte_size */, - const NullMap * /*nullmap*/, + const NullMap * /* nullmap */, const TiDB::TiDBCollatorPtr & /* collator */) const = 0; - virtual void countSerializeByteSize(PaddedPODArray & /* byte_size */) const = 0; /// Count the serialize byte size and added to the byte_size called by ColumnArray. /// array_offsets is the offsets of ColumnArray. /// The byte_size.size() must be equal to the array_offsets.size(). + virtual void countSerializeByteSizeForColumnArray( + PaddedPODArray & /* byte_size */, + const Offsets & /* array_offsets */) const + = 0; virtual void countSerializeByteSizeForCmpColumnArray( PaddedPODArray & /* byte_size */, const Offsets & /* array_offsets */, - const NullMap * /*nullmap*/, + const NullMap * /* nullmap */, const TiDB::TiDBCollatorPtr & /* collator */) const = 0; - virtual void countSerializeByteSizeForColumnArray( - PaddedPODArray & /* byte_size */, - const Offsets & /* array_offsets */) const - = 0; /// Serialize data of column from start to start + length into pointer of pos and forward each pos[i] to the end of /// serialized data. 
@@ -286,7 +295,7 @@ class IColumn : public COWPtr size_t /* start */, size_t /* length */, bool /* has_null */, - const NullMap * /*nullmap*/, + const NullMap * /* nullmap */, const TiDB::TiDBCollatorPtr & /* collator */, String * /* sort_key_container */) const = 0; @@ -307,7 +316,7 @@ class IColumn : public COWPtr size_t /* start */, size_t /* length */, bool /* has_null */, - const NullMap * /*nullmap*/, + const NullMap * /* nullmap */, const Offsets & /* array_offsets */, const TiDB::TiDBCollatorPtr & /* collator */, String * /* sort_key_container */) const @@ -331,20 +340,20 @@ class IColumn : public COWPtr /// } /// for (auto & column_ptr : mutable_columns) /// column_ptr->flushNTAlignBuffer(); + virtual void deserializeAndInsertFromPos(PaddedPODArray & /* pos */, bool /* use_nt_align_buffer */) = 0; virtual void deserializeForCmpAndInsertFromPos(PaddedPODArray & /* pos */, bool /* use_nt_align_buffer */) = 0; - virtual void deserializeAndInsertFromPos(PaddedPODArray & /* pos */, bool /* use_nt_align_buffer */) = 0; /// Deserialize and insert data from pos and forward each pos[i] to the end of serialized data. /// Only called by ColumnArray. /// array_offsets is the offsets of ColumnArray. /// The last pos.size() elements of array_offsets can be used to get the length of elements from each pos. 
- virtual void deserializeForCmpAndInsertFromPosColumnArray( + virtual void deserializeAndInsertFromPosForColumnArray( PaddedPODArray & /* pos */, const Offsets & /* array_offsets */, bool /* use_nt_align_buffer */) = 0; - virtual void deserializeAndInsertFromPosForColumnArray( + virtual void deserializeForCmpAndInsertFromPosColumnArray( PaddedPODArray & /* pos */, const Offsets & /* array_offsets */, bool /* use_nt_align_buffer */) diff --git a/dbms/src/Columns/IColumnDummy.h b/dbms/src/Columns/IColumnDummy.h index 9d64fb6d821..9ead751cb2e 100644 --- a/dbms/src/Columns/IColumnDummy.h +++ b/dbms/src/Columns/IColumnDummy.h @@ -88,6 +88,10 @@ class IColumnDummy : public IColumn return pos; } + void countSerializeByteSize(PaddedPODArray & /* byte_size */) const override + { + throw Exception("Method countSerializeByteSize is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); + } void countSerializeByteSizeForCmp( PaddedPODArray & /* byte_size */, const NullMap * /*nullmap*/, @@ -97,11 +101,15 @@ class IColumnDummy : public IColumn "Method countSerializeByteSizeForCmp is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void countSerializeByteSize(PaddedPODArray & /* byte_size */) const override + + void countSerializeByteSizeForColumnArray( + PaddedPODArray & /* byte_size */, + const IColumn::Offsets & /* array_offsets */) const override { - throw Exception("Method countSerializeByteSize is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); + throw Exception( + "Method countSerializeByteSizeForColumnArray is not supported for " + getName(), + ErrorCodes::NOT_IMPLEMENTED); } - void countSerializeByteSizeForCmpColumnArray( PaddedPODArray & /* byte_size */, const IColumn::Offsets & /* array_offsets */, @@ -112,15 +120,15 @@ class IColumnDummy : public IColumn "Method countSerializeByteSizeForCmpColumnArray is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void countSerializeByteSizeForColumnArray( - 
PaddedPODArray & /* byte_size */, - const IColumn::Offsets & /* array_offsets */) const override + + void serializeToPos( + PaddedPODArray & /* pos */, + size_t /* start */, + size_t /* length */, + bool /* has_null */) const override { - throw Exception( - "Method countSerializeByteSizeForColumnArray is not supported for " + getName(), - ErrorCodes::NOT_IMPLEMENTED); + throw Exception("Method serializeToPos is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void serializeToPosForCmp( PaddedPODArray & /* pos */, size_t /* start */, @@ -132,15 +140,18 @@ class IColumnDummy : public IColumn { throw Exception("Method serializeToPosForCmp is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void serializeToPos( + + void serializeToPosForColumnArray( PaddedPODArray & /* pos */, size_t /* start */, size_t /* length */, - bool /* has_null */) const override + bool /* has_null */, + const IColumn::Offsets & /* array_offsets */) const override { - throw Exception("Method serializeToPos is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); + throw Exception( + "Method serializeToPosForColumnArray is not supported for " + getName(), + ErrorCodes::NOT_IMPLEMENTED); } - void serializeToPosForCmpColumnArray( PaddedPODArray & /* pos */, size_t /* start */, @@ -155,47 +166,36 @@ class IColumnDummy : public IColumn "Method serializeToPosForCmpColumnArray is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void serializeToPosForColumnArray( - PaddedPODArray & /* pos */, - size_t /* start */, - size_t /* length */, - bool /* has_null */, - const IColumn::Offsets & /* array_offsets */) const override + + void deserializeAndInsertFromPos(PaddedPODArray & /* pos */, bool /* use_nt_align_buffer */) override { throw Exception( - "Method serializeToPosForColumnArray is not supported for " + getName(), + "Method deserializeAndInsertFromPos is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void 
deserializeForCmpAndInsertFromPos(PaddedPODArray & /* pos */, bool /* use_nt_align_buffer */) override { throw Exception( "Method deserializeForCmpAndInsertFromPos is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void deserializeAndInsertFromPos(PaddedPODArray & /* pos */, bool /* use_nt_align_buffer */) override - { - throw Exception( - "Method deserializeAndInsertFromPos is not supported for " + getName(), - ErrorCodes::NOT_IMPLEMENTED); - } - void deserializeForCmpAndInsertFromPosColumnArray( + void deserializeAndInsertFromPosForColumnArray( PaddedPODArray & /* pos */, const IColumn::Offsets & /* array_offsets */, bool /* use_nt_align_buffer */) override { throw Exception( - "Method deserializeForCmpAndInsertFromPosColumnArray is not supported for " + getName(), + "Method deserializeAndInsertFromPosForColumnArray is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void deserializeAndInsertFromPosForColumnArray( + void deserializeForCmpAndInsertFromPosColumnArray( PaddedPODArray & /* pos */, const IColumn::Offsets & /* array_offsets */, bool /* use_nt_align_buffer */) override { throw Exception( - "Method deserializeAndInsertFromPosForColumnArray is not supported for " + getName(), + "Method deserializeForCmpAndInsertFromPosColumnArray is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } @@ -217,10 +217,7 @@ class IColumnDummy : public IColumn void insertManyFrom(const IColumn &, size_t, size_t length) override { s += length; } - void insertSelectiveFrom(const IColumn &, const Offsets & selective_offsets) override - { - s += selective_offsets.size(); - } + void insertSelectiveRangeFrom(const IColumn &, const Offsets &, size_t, size_t length) override { s += length; } void insertRangeFrom(const IColumn & /*src*/, size_t /*start*/, size_t length) override { s += length; } diff --git a/dbms/src/Columns/tests/gtest_column_insertFrom.cpp b/dbms/src/Columns/tests/gtest_column_insertFrom.cpp index 
235acf51b71..b2e095946a1 100644 --- a/dbms/src/Columns/tests/gtest_column_insertFrom.cpp +++ b/dbms/src/Columns/tests/gtest_column_insertFrom.cpp @@ -83,9 +83,17 @@ class TestColumnInsertFrom : public ::testing::Test selective_offsets.push_back(4); for (size_t position : selective_offsets) cols[0]->insertFrom(*column_ptr, position); + std::vector> range_test = {{0, 1}, {1, 2}, {0, 3}, {2, 1}, {1, 1}}; + for (auto [start, length] : range_test) + { + for (size_t i = start; i < start + length; ++i) + cols[0]->insertFrom(*column_ptr, selective_offsets[i]); + } for (size_t position : selective_offsets) cols[0]->insertFrom(*column_ptr, position); cols[1]->insertSelectiveFrom(*column_ptr, selective_offsets); + for (auto [start, length] : range_test) + cols[1]->insertSelectiveRangeFrom(*column_ptr, selective_offsets, start, length); cols[1]->insertSelectiveFrom(*column_ptr, selective_offsets); { ColumnWithTypeAndName ref(std::move(cols[0]), col_with_type_and_name.type, ""); From 823b8754c450603106a4b10dfd785c0f1683bdf7 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Tue, 25 Mar 2025 18:12:02 +0800 Subject: [PATCH 08/84] u Signed-off-by: gengliqi --- dbms/src/Columns/ColumnDecimal.cpp | 352 +++++++++++------------------ dbms/src/Columns/ColumnDecimal.h | 48 ++-- dbms/src/Columns/ColumnString.cpp | 132 ++++------- dbms/src/Columns/ColumnVector.cpp | 46 ++-- dbms/src/Columns/ColumnVector.h | 31 ++- 5 files changed, 231 insertions(+), 378 deletions(-) diff --git a/dbms/src/Columns/ColumnDecimal.cpp b/dbms/src/Columns/ColumnDecimal.cpp index 72aded07965..fdbc93fada0 100644 --- a/dbms/src/Columns/ColumnDecimal.cpp +++ b/dbms/src/Columns/ColumnDecimal.cpp @@ -141,22 +141,14 @@ const char * ColumnDecimal::deserializeAndInsertFromArena(const char * pos, c } } -template -void ColumnDecimal::countSerializeByteSizeForCmp( - PaddedPODArray & byte_size, - const NullMap * nullmap, - const TiDB::TiDBCollatorPtr &) const -{ - if (nullmap != nullptr) - countSerializeByteSizeImpl(byte_size, 
nullmap); - else - countSerializeByteSizeImpl(byte_size, nullptr); -} - template void ColumnDecimal::countSerializeByteSize(PaddedPODArray & byte_size) const { - countSerializeByteSizeImpl(byte_size, nullptr); + RUNTIME_CHECK_MSG(byte_size.size() == size(), "size of byte_size({}) != column size({})", byte_size.size(), size()); + + size_t size = byte_size.size(); + for (size_t i = 0; i < size; ++i) + byte_size[i] += sizeof(T); } template @@ -167,9 +159,9 @@ void ColumnDecimal::countSerializeByteSizeForCmpColumnArray( const TiDB::TiDBCollatorPtr &) const { if (nullmap != nullptr) - countSerializeByteSizeForColumnArrayImpl(byte_size, array_offsets, nullmap); + countSerializeByteSizeForColumnArrayImpl(byte_size, array_offsets, nullmap); else - countSerializeByteSizeForColumnArrayImpl(byte_size, array_offsets, nullptr); + countSerializeByteSizeForColumnArrayImpl(byte_size, array_offsets, nullptr); } template @@ -177,40 +169,11 @@ void ColumnDecimal::countSerializeByteSizeForColumnArray( PaddedPODArray & byte_size, const IColumn::Offsets & array_offsets) const { - countSerializeByteSizeForColumnArrayImpl(byte_size, array_offsets, nullptr); -} - -template -template -void ColumnDecimal::countSerializeByteSizeImpl(PaddedPODArray & byte_size, const NullMap * nullmap) const -{ - RUNTIME_CHECK_MSG(byte_size.size() == size(), "size of byte_size({}) != column size({})", byte_size.size(), size()); - - size_t size = byte_size.size(); - static constexpr T def_val{}; - for (size_t i = 0; i < size; ++i) - { - if constexpr (compare_semantics && is_Decimal256) - { - if constexpr (has_nullmap) - { - if (DB::isNullAt(*nullmap, i)) - { - byte_size[i] += getDecimal256BytesSize(def_val); - continue; - } - } - byte_size[i] += getDecimal256BytesSize(data[i]); - } - else - { - byte_size[i] += sizeof(T); - } - } + countSerializeByteSizeForColumnArrayImpl(byte_size, array_offsets, nullptr); } template -template +template void ColumnDecimal::countSerializeByteSizeForColumnArrayImpl( 
PaddedPODArray & byte_size, const IColumn::Offsets & array_offsets, @@ -230,18 +193,7 @@ void ColumnDecimal::countSerializeByteSizeForColumnArrayImpl( if (DB::isNullAt(*nullmap, i)) continue; } - if constexpr (compare_semantics && is_Decimal256) - { - size_t cur_size = 0; - for (size_t j = array_offsets[i - 1]; j < array_offsets[i]; ++j) - cur_size += getDecimal256BytesSize(data[j]); - - byte_size[i] += cur_size; - } - else - { - byte_size[i] += sizeof(T) * (array_offsets[i] - array_offsets[i - 1]); - } + byte_size[i] += sizeof(T) * (array_offsets[i] - array_offsets[i - 1]); } } @@ -255,36 +207,27 @@ void ColumnDecimal::serializeToPosForCmp( const TiDB::TiDBCollatorPtr &, String *) const { +#define CALL(has_null, has_nullmap) \ + { \ + serializeToPosImpl(pos, start, length, nullmap); \ + } + if (has_null) { if (nullmap != nullptr) - serializeToPosImpl( - pos, - start, - length, - nullmap); + CALL(true, true) else - serializeToPosImpl( - pos, - start, - length, - nullptr); + CALL(true, false) } else { if (nullmap != nullptr) - serializeToPosImpl( - pos, - start, - length, - nullmap); + CALL(false, true) else - serializeToPosImpl( - pos, - start, - length, - nullptr); + CALL(false, false) } + +#undef CALL } template @@ -304,6 +247,53 @@ void ColumnDecimal::serializeToPos(PaddedPODArray & pos, size_t start nullptr); } +template +template +void ColumnDecimal::serializeToPosImpl( + PaddedPODArray & pos, + size_t start, + size_t length, + const NullMap * nullmap) const +{ + RUNTIME_CHECK_MSG(length <= pos.size(), "length({}) > size of pos({})", length, pos.size()); + RUNTIME_CHECK_MSG(start + length <= size(), "start({}) + length({}) > size of column({})", start, length, size()); + + RUNTIME_CHECK(!has_nullmap || (nullmap && nullmap->size() == size())); + + static constexpr T def_val{}; + T tmp_val{}; + for (size_t i = 0; i < length; ++i) + { + if constexpr (has_null) + { + if (pos[i] == nullptr) + continue; + } + if constexpr (has_nullmap) + { + if 
(DB::isNullAt(*nullmap, start + i)) + { + tiflash_compiler_builtin_memcpy(pos[i], &def_val, sizeof(T)); + pos[i] += sizeof(T); + continue; + } + } + + if constexpr (compare_semantics && is_Decimal256) + { + // Clear the value and only set the necessary parts for compare semantic + memset(static_cast(&tmp_val), 0, sizeof(T)); + tmp_val.value.backend().assign(data[start + i].value.backend()); + tiflash_compiler_builtin_memcpy(pos[i], &tmp_val, sizeof(T)); + } + else + { + tiflash_compiler_builtin_memcpy(pos[i], &data[start + i], sizeof(T)); + } + pos[i] += sizeof(T); + } +} + template void ColumnDecimal::serializeToPosForCmpColumnArray( PaddedPODArray & pos, @@ -315,40 +305,27 @@ void ColumnDecimal::serializeToPosForCmpColumnArray( const TiDB::TiDBCollatorPtr &, String *) const { +#define CALL(has_null, has_nullmap) \ + { \ + serializeToPosForColumnArrayImpl(pos, start, length, array_offsets, nullmap); \ + } + if (has_null) { if (nullmap != nullptr) - serializeToPosForColumnArrayImpl( - pos, - start, - length, - array_offsets, - nullmap); + CALL(true, true) else - serializeToPosForColumnArrayImpl( - pos, - start, - length, - array_offsets, - nullptr); + CALL(true, false) } else { if (nullmap != nullptr) - serializeToPosForColumnArrayImpl( - pos, - start, - length, - array_offsets, - nullmap); + CALL(false, true) else - serializeToPosForColumnArrayImpl( - pos, - start, - length, - array_offsets, - nullptr); + CALL(false, false) } + +#undef CALL } template @@ -375,56 +352,6 @@ void ColumnDecimal::serializeToPosForColumnArray( nullptr); } -template -template -void ColumnDecimal::serializeToPosImpl( - PaddedPODArray & pos, - size_t start, - size_t length, - const NullMap * nullmap) const -{ - RUNTIME_CHECK_MSG(length <= pos.size(), "length({}) > size of pos({})", length, pos.size()); - RUNTIME_CHECK_MSG(start + length <= size(), "start({}) + length({}) > size of column({})", start, length, size()); - - RUNTIME_CHECK(!has_nullmap || (nullmap && nullmap->size() == 
size())); - - static constexpr T def_val{}; - for (size_t i = 0; i < length; ++i) - { - if constexpr (has_null) - { - if (pos[i] == nullptr) - continue; - } - if constexpr (has_nullmap) - { - if (DB::isNullAt(*nullmap, start + i)) - { - if constexpr (compare_semantics && is_Decimal256) - { - pos[i] = serializeDecimal256Helper(pos[i], def_val); - } - else - { - tiflash_compiler_builtin_memcpy(pos[i], &def_val, sizeof(T)); - pos[i] += sizeof(T); - } - continue; - } - } - - if constexpr (compare_semantics && is_Decimal256) - { - pos[i] = serializeDecimal256Helper(pos[i], data[start + i]); - } - else - { - tiflash_compiler_builtin_memcpy(pos[i], &data[start + i], sizeof(T)); - pos[i] += sizeof(T); - } - } -} - template template void ColumnDecimal::serializeToPosForColumnArrayImpl( @@ -449,6 +376,7 @@ void ColumnDecimal::serializeToPosForColumnArrayImpl( RUNTIME_CHECK(!has_nullmap || (nullmap && nullmap->size() == array_offsets.size())); + T tmp_val{}; for (size_t i = 0; i < length; ++i) { if constexpr (has_null) @@ -465,75 +393,51 @@ void ColumnDecimal::serializeToPosForColumnArrayImpl( size_t len = array_offsets[start + i] - array_offsets[start + i - 1]; if constexpr (compare_semantics && is_Decimal256) { + auto * p = pos[i]; for (size_t j = 0; j < len; ++j) - pos[i] = serializeDecimal256Helper(pos[i], data[array_offsets[start + i - 1] + j]); + { + // Clear the value and only set the necessary parts for compare semantic + memset(static_cast(&tmp_val), 0, sizeof(T)); + tmp_val.value.backend().assign(data[start + i].value.backend()); + tiflash_compiler_builtin_memcpy(p, &tmp_val, sizeof(T)); + p += sizeof(T); + } + pos[i] = p; } else { + auto start_idx = array_offsets[start + i - 1]; if (len <= 4) { + auto * p = pos[i]; for (size_t j = 0; j < len; ++j) - tiflash_compiler_builtin_memcpy( - pos[i] + j * sizeof(T), - &data[array_offsets[start + i - 1] + j], - sizeof(T)); + { + tiflash_compiler_builtin_memcpy(p, &data[start_idx + j], sizeof(T)); + p += sizeof(T); + } + 
pos[i] = p; } else { - inline_memcpy(pos[i], &data[array_offsets[start + i - 1]], len * sizeof(T)); + inline_memcpy(pos[i], &data[start_idx], len * sizeof(T)); + pos[i] += len * sizeof(T); } - pos[i] += len * sizeof(T); } } } template -void ColumnDecimal::deserializeForCmpAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) -{ - deserializeAndInsertFromPosImpl(pos, use_nt_align_buffer); -} - -template -void ColumnDecimal::deserializeAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) -{ - deserializeAndInsertFromPosImpl(pos, use_nt_align_buffer); -} - -template -void ColumnDecimal::deserializeForCmpAndInsertFromPosColumnArray( - PaddedPODArray & pos, - const IColumn::Offsets & array_offsets, - bool use_nt_align_buffer) -{ - deserializeAndInsertFromPosForColumnArrayImpl(pos, array_offsets, use_nt_align_buffer); -} - -template -void ColumnDecimal::deserializeAndInsertFromPosForColumnArray( - PaddedPODArray & pos, - const IColumn::Offsets & array_offsets, - bool use_nt_align_buffer) -{ - deserializeAndInsertFromPosForColumnArrayImpl(pos, array_offsets, use_nt_align_buffer); -} - -template -template -void ColumnDecimal::deserializeAndInsertFromPosImpl( +void ColumnDecimal::deserializeAndInsertFromPos( PaddedPODArray & pos, bool use_nt_align_buffer [[maybe_unused]]) { size_t prev_size = data.size(); size_t size = pos.size(); - // is_complex_decimal256 is true means Decimal256 is serialized by [bool, limb_count, n * limb]. - // NT optimization is not implemented for simplicity. 
- static const bool is_complex_decimal256 = (compare_semantics && is_Decimal256); - #ifdef TIFLASH_ENABLE_AVX_SUPPORT if (use_nt_align_buffer) { - if constexpr ((FULL_VECTOR_SIZE_AVX2 % sizeof(T) == 0) && !is_complex_decimal256) + if constexpr ((FULL_VECTOR_SIZE_AVX2 % sizeof(T) == 0)) { bool is_aligned = reinterpret_cast(&data[prev_size]) % FULL_VECTOR_SIZE_AVX2 == 0; if likely (is_aligned) @@ -610,24 +514,21 @@ void ColumnDecimal::deserializeAndInsertFromPosImpl( #endif data.resize(prev_size + size); - if constexpr (is_complex_decimal256) + for (size_t i = 0; i < size; ++i) { - for (size_t i = 0; i < size; ++i) - pos[i] = const_cast(deserializeDecimal256Helper(data[prev_size + i], pos[i])); - } - else - { - for (size_t i = 0; i < size; ++i) - { - tiflash_compiler_builtin_memcpy(&data[prev_size + i], pos[i], sizeof(T)); - pos[i] += sizeof(T); - } + tiflash_compiler_builtin_memcpy(&data[prev_size + i], pos[i], sizeof(T)); + pos[i] += sizeof(T); } } template -template -void ColumnDecimal::deserializeAndInsertFromPosForColumnArrayImpl( +void ColumnDecimal::deserializeForCmpAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) +{ + deserializeAndInsertFromPos(pos, use_nt_align_buffer); +} + +template +void ColumnDecimal::deserializeAndInsertFromPosForColumnArray( PaddedPODArray & pos, const IColumn::Offsets & array_offsets, bool use_nt_align_buffer [[maybe_unused]]) @@ -653,31 +554,34 @@ void ColumnDecimal::deserializeAndInsertFromPosForColumnArrayImpl( for (size_t i = 0; i < size; ++i) { size_t len = array_offsets[start_point + i] - array_offsets[start_point + i - 1]; - if constexpr (compare_semantics && is_Decimal256) + auto start_idx = array_offsets[start_point + i - 1]; + if (len <= 4) { + auto * p = pos[i]; for (size_t j = 0; j < len; ++j) - pos[i] = const_cast( - deserializeDecimal256Helper(data[array_offsets[start_point + i - 1] + j], pos[i])); + { + tiflash_compiler_builtin_memcpy(&data[start_idx + j], p, sizeof(T)); + p += sizeof(T); + } + pos[i] = 
p; } else { - if (len <= 4) - { - for (size_t j = 0; j < len; ++j) - tiflash_compiler_builtin_memcpy( - &data[array_offsets[start_point + i - 1] + j], - pos[i] + j * sizeof(T), - sizeof(T)); - } - else - { - inline_memcpy(&data[array_offsets[start_point + i - 1]], pos[i], len * sizeof(T)); - } + inline_memcpy(&data[start_idx], pos[i], len * sizeof(T)); pos[i] += len * sizeof(T); } } } +template +void ColumnDecimal::deserializeForCmpAndInsertFromPosColumnArray( + PaddedPODArray & pos, + const IColumn::Offsets & array_offsets, + bool use_nt_align_buffer) +{ + deserializeAndInsertFromPosForColumnArray(pos, array_offsets, use_nt_align_buffer); +} + template void ColumnDecimal::flushNTAlignBuffer() { diff --git a/dbms/src/Columns/ColumnDecimal.h b/dbms/src/Columns/ColumnDecimal.h index 81146cff245..b64ca3d30c4 100644 --- a/dbms/src/Columns/ColumnDecimal.h +++ b/dbms/src/Columns/ColumnDecimal.h @@ -101,33 +101,6 @@ class ColumnDecimal final : public COWPtrHelper - void countSerializeByteSizeImpl(PaddedPODArray & byte_size, const NullMap * nullmap) const; - template - void countSerializeByteSizeForColumnArrayImpl( - PaddedPODArray & byte_size, - const IColumn::Offsets & array_offsets, - const NullMap * nullmap) const; - - template - void serializeToPosImpl(PaddedPODArray & pos, size_t start, size_t length, const NullMap * nullmap) const; - template - void serializeToPosForColumnArrayImpl( - PaddedPODArray & pos, - size_t start, - size_t length, - const IColumn::Offsets & array_offsets, - const NullMap * nullmap) const; - - template - void deserializeAndInsertFromPosImpl(PaddedPODArray & pos, bool use_nt_align_buffer [[maybe_unused]]); - - template - void deserializeAndInsertFromPosForColumnArrayImpl( - PaddedPODArray & pos, - const IColumn::Offsets & array_offsets, - bool use_nt_align_buffer [[maybe_unused]]); - public: const char * getFamilyName() const override { return TypeName::get(); } @@ -182,8 +155,11 @@ class ColumnDecimal final : public COWPtrHelper & byte_size) 
const override; void countSerializeByteSizeForCmp( PaddedPODArray & byte_size, - const NullMap * nullmap, - const TiDB::TiDBCollatorPtr &) const override; + const NullMap * /* nullmap */, + const TiDB::TiDBCollatorPtr &) const override + { + countSerializeByteSize(byte_size); + } void countSerializeByteSizeForColumnArray( PaddedPODArray & byte_size, @@ -193,6 +169,11 @@ class ColumnDecimal final : public COWPtrHelper + void countSerializeByteSizeForColumnArrayImpl( + PaddedPODArray & byte_size, + const IColumn::Offsets & array_offsets, + const NullMap * nullmap) const; void serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const override; void serializeToPosForCmp( @@ -203,6 +184,8 @@ class ColumnDecimal final : public COWPtrHelper + void serializeToPosImpl(PaddedPODArray & pos, size_t start, size_t length, const NullMap * nullmap) const; void serializeToPosForColumnArray( PaddedPODArray & pos, @@ -219,6 +202,13 @@ class ColumnDecimal final : public COWPtrHelper + void serializeToPosForColumnArrayImpl( + PaddedPODArray & pos, + size_t start, + size_t length, + const IColumn::Offsets & array_offsets, + const NullMap * nullmap) const; void deserializeAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override; void deserializeForCmpAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override; diff --git a/dbms/src/Columns/ColumnString.cpp b/dbms/src/Columns/ColumnString.cpp index e22390ec00e..c7349db4e18 100644 --- a/dbms/src/Columns/ColumnString.cpp +++ b/dbms/src/Columns/ColumnString.cpp @@ -1142,65 +1142,47 @@ void ColumnString::serializeToPosForColumnArrayImpl( if (DB::isNullAt(*nullmap, start + i)) continue; } - if constexpr (compare_semantics) + if constexpr (compare_semantics && need_decode_collator) { + auto * size_pos = pos[i]; + auto * p = pos[i]; + p += (array_offsets[start + i] - array_offsets[start + i - 1]) * sizeof(UInt32); for (size_t j = array_offsets[start + i - 1]; j < array_offsets[start + 
i]; ++j) { UInt32 str_size = sizeAt(j); const void * src = &chars[offsetAt(j)]; - if constexpr (need_decode_collator) - { - auto sort_key = derived_collator->sortKey( - reinterpret_cast(src), - str_size - 1, - *sort_key_container); - // For terminating zero. - str_size = sort_key.size + 1; - - tiflash_compiler_builtin_memcpy(pos[i], &str_size, sizeof(UInt32)); - pos[i] += sizeof(UInt32); - inline_memcpy(pos[i], sort_key.data, sort_key.size); - pos[i] += sort_key.size; - *(pos[i]) = '\0'; - pos[i] += 1; - } - else - { - tiflash_compiler_builtin_memcpy(pos[i], &str_size, sizeof(UInt32)); - pos[i] += sizeof(UInt32); - inline_memcpy(pos[i], src, str_size); - pos[i] += str_size; - } + auto sort_key + = derived_collator->sortKey(reinterpret_cast(src), str_size - 1, *sort_key_container); + // For terminating zero. + str_size = sort_key.size + 1; + + tiflash_compiler_builtin_memcpy(size_pos, &str_size, sizeof(UInt32)); + size_pos += sizeof(UInt32); + inline_memcpy(p, sort_key.data, sort_key.size); + p += sort_key.size; + *p = '\0'; + p += 1; } + pos[i] = p; } else { + auto * p = pos[i]; for (size_t j = array_offsets[start + i - 1]; j < array_offsets[start + i]; ++j) { UInt32 str_size = sizeAt(j); - tiflash_compiler_builtin_memcpy(pos[i], &str_size, sizeof(UInt32)); - pos[i] += sizeof(UInt32); + tiflash_compiler_builtin_memcpy(p, &str_size, sizeof(UInt32)); + p += sizeof(UInt32); } size_t strs_size = offsetAt(array_offsets[start + i]) - offsetAt(array_offsets[start + i - 1]); - inline_memcpy(pos[i], &chars[offsetAt(array_offsets[start + i - 1])], strs_size); - pos[i] += strs_size; + inline_memcpy(p, &chars[offsetAt(array_offsets[start + i - 1])], strs_size); + p += strs_size; + pos[i] = p; } } } -void ColumnString::deserializeForCmpAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) -{ - deserializeAndInsertFromPosImpl(pos, use_nt_align_buffer); -} - -void ColumnString::deserializeAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) -{ - 
deserializeAndInsertFromPosImpl(pos, use_nt_align_buffer); -} - -void ColumnString::deserializeAndInsertFromPosImpl( - PaddedPODArray & pos, - bool use_nt_align_buffer [[maybe_unused]]) +void ColumnString::deserializeAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer [[maybe_unused]]) { size_t prev_size = offsets.size(); size_t char_size = chars.size(); @@ -1307,24 +1289,12 @@ void ColumnString::deserializeAndInsertFromPosImpl( } } -void ColumnString::deserializeForCmpAndInsertFromPosColumnArray( - PaddedPODArray & pos, - const IColumn::Offsets & array_offsets, - bool use_nt_align_buffer) +void ColumnString::deserializeForCmpAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) { - deserializeAndInsertFromPosForColumnArrayImpl(pos, array_offsets, use_nt_align_buffer); + deserializeAndInsertFromPos(pos, use_nt_align_buffer); } void ColumnString::deserializeAndInsertFromPosForColumnArray( - PaddedPODArray & pos, - const IColumn::Offsets & array_offsets, - bool use_nt_align_buffer) -{ - deserializeAndInsertFromPosForColumnArrayImpl(pos, array_offsets, use_nt_align_buffer); -} - -template -void ColumnString::deserializeAndInsertFromPosForColumnArrayImpl( PaddedPODArray & pos, const IColumn::Offsets & array_offsets, bool use_nt_align_buffer [[maybe_unused]]) @@ -1348,45 +1318,31 @@ void ColumnString::deserializeAndInsertFromPosForColumnArrayImpl( size_t size = pos.size(); size_t char_size = chars.size(); - if constexpr (compare_semantics) - { - for (size_t i = 0; i < size; ++i) - { - for (size_t j = array_offsets[start_point + i - 1]; j < array_offsets[start_point + i]; ++j) - { - UInt32 str_size; - tiflash_compiler_builtin_memcpy(&str_size, pos[i], sizeof(UInt32)); - pos[i] += sizeof(UInt32); - - chars.resize(char_size + str_size); - inline_memcpy(&chars[char_size], pos[i], str_size); - - char_size += str_size; - offsets[j] = char_size; - pos[i] += str_size; - } - } - } - else + for (size_t i = 0; i < size; ++i) { - for (size_t i = 0; i < size; 
++i) + size_t prev_char_size = char_size; + for (size_t j = array_offsets[start_point + i - 1]; j < array_offsets[start_point + i]; ++j) { - size_t prev_char_size = char_size; - for (size_t j = array_offsets[start_point + i - 1]; j < array_offsets[start_point + i]; ++j) - { - UInt32 str_size; - tiflash_compiler_builtin_memcpy(&str_size, pos[i], sizeof(UInt32)); - pos[i] += sizeof(UInt32); - char_size += str_size; - offsets[j] = char_size; - } - chars.resize(char_size); - inline_memcpy(&chars[prev_char_size], pos[i], char_size - prev_char_size); - pos[i] += char_size - prev_char_size; + UInt32 str_size; + tiflash_compiler_builtin_memcpy(&str_size, pos[i], sizeof(UInt32)); + pos[i] += sizeof(UInt32); + char_size += str_size; + offsets[j] = char_size; } + chars.resize(char_size); + inline_memcpy(&chars[prev_char_size], pos[i], char_size - prev_char_size); + pos[i] += char_size - prev_char_size; } } +void ColumnString::deserializeForCmpAndInsertFromPosColumnArray( + PaddedPODArray & pos, + const IColumn::Offsets & array_offsets, + bool use_nt_align_buffer) +{ + deserializeAndInsertFromPosForColumnArrayImpl(pos, array_offsets, use_nt_align_buffer); +} + void ColumnString::flushNTAlignBuffer() { #ifdef TIFLASH_ENABLE_AVX_SUPPORT diff --git a/dbms/src/Columns/ColumnVector.cpp b/dbms/src/Columns/ColumnVector.cpp index 1218a5cef20..fc39ae1c852 100644 --- a/dbms/src/Columns/ColumnVector.cpp +++ b/dbms/src/Columns/ColumnVector.cpp @@ -68,6 +68,14 @@ void ColumnVector::countSerializeByteSize(PaddedPODArray & byte_size) byte_size[i] += sizeof(T); } +template +void ColumnVector::countSerializeByteSizeForColumnArray( + PaddedPODArray & byte_size, + const IColumn::Offsets & array_offsets) const +{ + countSerializeByteSizeForColumnArrayImpl(byte_size, array_offsets, nullptr); +} + template void ColumnVector::countSerializeByteSizeForCmpColumnArray( PaddedPODArray & byte_size, @@ -81,14 +89,6 @@ void ColumnVector::countSerializeByteSizeForCmpColumnArray( 
countSerializeByteSizeForColumnArrayImpl(byte_size, array_offsets, nullptr); } -template -void ColumnVector::countSerializeByteSizeForColumnArray( - PaddedPODArray & byte_size, - const IColumn::Offsets & array_offsets) const -{ - countSerializeByteSizeForColumnArrayImpl(byte_size, array_offsets, nullptr); -} - template template void ColumnVector::countSerializeByteSizeForColumnArrayImpl( @@ -262,19 +262,22 @@ void ColumnVector::serializeToPosForColumnArrayImpl( continue; } size_t len = array_offsets[start + i] - array_offsets[start + i - 1]; + auto start_idx = array_offsets[start + i - 1]; if (len <= 4) { + auto * p = pos[i]; for (size_t j = 0; j < len; ++j) - tiflash_compiler_builtin_memcpy( - pos[i] + j * sizeof(T), - &data[array_offsets[start + i - 1] + j], - sizeof(T)); + { + tiflash_compiler_builtin_memcpy(p, &data[start_idx + j], sizeof(T)); + p += sizeof(T); + } + pos[i] = p; } else { - inline_memcpy(pos[i], &data[array_offsets[start + i - 1]], len * sizeof(T)); + inline_memcpy(pos[i], &data[start_idx], len * sizeof(T)); + pos[i] += len * sizeof(T); } - pos[i] += len * sizeof(T); } } @@ -400,19 +403,22 @@ void ColumnVector::deserializeAndInsertFromPosForColumnArray( for (size_t i = 0; i < size; ++i) { size_t len = array_offsets[start_point + i] - array_offsets[start_point + i - 1]; + auto start_idx = array_offsets[start_point + i - 1]; if (len <= 4) { + auto * p = pos[i]; for (size_t j = 0; j < len; ++j) - tiflash_compiler_builtin_memcpy( - &data[array_offsets[start_point + i - 1] + j], - pos[i] + j * sizeof(T), - sizeof(T)); + { + tiflash_compiler_builtin_memcpy(&data[start_idx + j], p, sizeof(T)); + p += sizeof(T); + } + pos[i] = p; } else { - inline_memcpy(&data[array_offsets[start_point + i - 1]], pos[i], len * sizeof(T)); + inline_memcpy(&data[start_idx], pos[i], len * sizeof(T)); + pos[i] += len * sizeof(T); } - pos[i] += len * sizeof(T); } } diff --git a/dbms/src/Columns/ColumnVector.h b/dbms/src/Columns/ColumnVector.h index 17f76f2e135..096ad6b9dbf 
100644 --- a/dbms/src/Columns/ColumnVector.h +++ b/dbms/src/Columns/ColumnVector.h @@ -198,23 +198,6 @@ class ColumnVector final : public COWPtrHelper - void countSerializeByteSizeForColumnArrayImpl( - PaddedPODArray & byte_size, - const IColumn::Offsets & array_offsets, - const NullMap * nullmap) const; - - template - void serializeToPosImpl(PaddedPODArray & pos, size_t start, size_t length, const NullMap * nullmap) const; - - template - void serializeToPosForColumnArrayImpl( - PaddedPODArray & pos, - size_t start, - size_t length, - const IColumn::Offsets & array_offsets, - const NullMap * nullmap) const; - public: bool isNumeric() const override { return is_arithmetic_v; } @@ -355,6 +338,11 @@ class ColumnVector final : public COWPtrHelper + void countSerializeByteSizeForColumnArrayImpl( + PaddedPODArray & byte_size, + const IColumn::Offsets & array_offsets, + const NullMap * nullmap) const; void serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const override; void serializeToPosForCmp( @@ -365,6 +353,8 @@ class ColumnVector final : public COWPtrHelper + void serializeToPosImpl(PaddedPODArray & pos, size_t start, size_t length, const NullMap * nullmap) const; void serializeToPosForColumnArray( PaddedPODArray & pos, @@ -381,6 +371,13 @@ class ColumnVector final : public COWPtrHelper + void serializeToPosForColumnArrayImpl( + PaddedPODArray & pos, + size_t start, + size_t length, + const IColumn::Offsets & array_offsets, + const NullMap * nullmap) const; void deserializeAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override; void deserializeForCmpAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override From 41bb1bcfa1a3468bb4be582de42197e544a9b9ce Mon Sep 17 00:00:00 2001 From: gengliqi Date: Tue, 25 Mar 2025 23:24:55 +0800 Subject: [PATCH 09/84] u Signed-off-by: gengliqi --- dbms/src/Columns/ColumnAggregateFunction.h | 15 - dbms/src/Columns/ColumnArray.cpp | 16 +- dbms/src/Columns/ColumnArray.h | 
10 - dbms/src/Columns/ColumnConst.h | 15 - dbms/src/Columns/ColumnDecimal.cpp | 33 +- dbms/src/Columns/ColumnDecimal.h | 5 - dbms/src/Columns/ColumnFixedString.h | 11 - dbms/src/Columns/ColumnFunction.h | 15 - dbms/src/Columns/ColumnNullable.cpp | 63 +-- dbms/src/Columns/ColumnNullable.h | 5 - dbms/src/Columns/ColumnString.cpp | 477 ++++++------------ dbms/src/Columns/ColumnString.h | 93 +--- dbms/src/Columns/ColumnTuple.h | 16 - dbms/src/Columns/ColumnVector.h | 11 - dbms/src/Columns/IColumn.h | 7 - dbms/src/Columns/IColumnDummy.h | 15 - .../gtest_column_serialize_deserialize.cpp | 27 +- dbms/src/Interpreters/Aggregator.h | 2 +- dbms/src/TiDB/Collation/Collator.h | 36 +- 19 files changed, 244 insertions(+), 628 deletions(-) diff --git a/dbms/src/Columns/ColumnAggregateFunction.h b/dbms/src/Columns/ColumnAggregateFunction.h index e0b30771f9c..6e8c6563ccb 100644 --- a/dbms/src/Columns/ColumnAggregateFunction.h +++ b/dbms/src/Columns/ColumnAggregateFunction.h @@ -252,12 +252,6 @@ class ColumnAggregateFunction final : public COWPtrHelper & /* pos */, bool /* use_nt_align_buffer */) override - { - throw Exception( - "Method deserializeForCmpAndInsertFromPos is not supported for " + getName(), - ErrorCodes::NOT_IMPLEMENTED); - } void deserializeAndInsertFromPosForColumnArray( PaddedPODArray & /* pos */, @@ -268,15 +262,6 @@ class ColumnAggregateFunction final : public COWPtrHelper & /* pos */, - const IColumn::Offsets & /* array_offsets */, - bool /* use_nt_align_buffer */) override - { - throw Exception( - "Method deserializeForCmpAndInsertFromPosColumnArray is not supported for " + getName(), - ErrorCodes::NOT_IMPLEMENTED); - } void flushNTAlignBuffer() override { diff --git a/dbms/src/Columns/ColumnArray.cpp b/dbms/src/Columns/ColumnArray.cpp index 98ffc303199..b9fe2f7f784 100644 --- a/dbms/src/Columns/ColumnArray.cpp +++ b/dbms/src/Columns/ColumnArray.cpp @@ -376,18 +376,7 @@ void ColumnArray::serializeToPosImpl( getData().serializeToPosForColumnArray(pos, start, 
length, has_null, getOffsets()); } -void ColumnArray::deserializeForCmpAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) -{ - deserializeAndInsertFromPosImpl(pos, use_nt_align_buffer); -} - void ColumnArray::deserializeAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) -{ - deserializeAndInsertFromPosImpl(pos, use_nt_align_buffer); -} - -template -void ColumnArray::deserializeAndInsertFromPosImpl(PaddedPODArray & pos, bool use_nt_align_buffer) { auto & offsets = getOffsets(); size_t prev_size = offsets.size(); @@ -402,10 +391,7 @@ void ColumnArray::deserializeAndInsertFromPosImpl(PaddedPODArray & pos, pos[i] += sizeof(UInt32); } - if constexpr (compare_semantics) - getData().deserializeForCmpAndInsertFromPosColumnArray(pos, offsets, use_nt_align_buffer); - else - getData().deserializeAndInsertFromPosForColumnArray(pos, offsets, use_nt_align_buffer); + getData().deserializeAndInsertFromPosForColumnArray(pos, offsets, use_nt_align_buffer); } void ColumnArray::flushNTAlignBuffer() diff --git a/dbms/src/Columns/ColumnArray.h b/dbms/src/Columns/ColumnArray.h index 572523ce14a..b9f50ba62fe 100644 --- a/dbms/src/Columns/ColumnArray.h +++ b/dbms/src/Columns/ColumnArray.h @@ -162,7 +162,6 @@ class ColumnArray final : public COWPtrHelper } void deserializeAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override; - void deserializeForCmpAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override; void deserializeAndInsertFromPosForColumnArray( PaddedPODArray & /* pos */, @@ -173,15 +172,6 @@ class ColumnArray final : public COWPtrHelper "Method deserializeAndInsertFromPosForColumnArray is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void deserializeForCmpAndInsertFromPosColumnArray( - PaddedPODArray & /* pos */, - const IColumn::Offsets & /* array_offsets */, - bool /* use_nt_align_buffer */) override - { - throw Exception( - "Method deserializeForCmpAndInsertFromPosColumnArray is not supported 
for " + getName(), - ErrorCodes::NOT_IMPLEMENTED); - } void flushNTAlignBuffer() override; diff --git a/dbms/src/Columns/ColumnConst.h b/dbms/src/Columns/ColumnConst.h index 41d44c8ab89..0702bebe5ce 100644 --- a/dbms/src/Columns/ColumnConst.h +++ b/dbms/src/Columns/ColumnConst.h @@ -194,12 +194,6 @@ class ColumnConst final : public COWPtrHelper "Method deserializeAndInsertFromPos is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void deserializeForCmpAndInsertFromPos(PaddedPODArray & /* pos */, bool /* use_nt_align_buffer */) override - { - throw Exception( - "Method deserializeForCmpAndInsertFromPos is not supported for " + getName(), - ErrorCodes::NOT_IMPLEMENTED); - } void deserializeAndInsertFromPosForColumnArray( PaddedPODArray & /* pos */, @@ -210,15 +204,6 @@ class ColumnConst final : public COWPtrHelper "Method deserializeAndInsertFromPosForColumnArray is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void deserializeForCmpAndInsertFromPosColumnArray( - PaddedPODArray & /* pos */, - const IColumn::Offsets & /* array_offsets */, - bool /* use_nt_align_buffer */) override - { - throw Exception( - "Method deserializeForCmpAndInsertFromPosColumnArray is not supported for " + getName(), - ErrorCodes::NOT_IMPLEMENTED); - } void flushNTAlignBuffer() override { diff --git a/dbms/src/Columns/ColumnDecimal.cpp b/dbms/src/Columns/ColumnDecimal.cpp index fdbc93fada0..4067e3a049d 100644 --- a/dbms/src/Columns/ColumnDecimal.cpp +++ b/dbms/src/Columns/ColumnDecimal.cpp @@ -305,9 +305,14 @@ void ColumnDecimal::serializeToPosForCmpColumnArray( const TiDB::TiDBCollatorPtr &, String *) const { -#define CALL(has_null, has_nullmap) \ - { \ - serializeToPosForColumnArrayImpl(pos, start, length, array_offsets, nullmap); \ +#define CALL(has_null, has_nullmap) \ + { \ + serializeToPosForColumnArrayImpl( \ + pos, \ + start, \ + length, \ + array_offsets, \ + nullmap); \ } if (has_null) @@ -391,14 +396,16 @@ void 
ColumnDecimal::serializeToPosForColumnArrayImpl( } size_t len = array_offsets[start + i] - array_offsets[start + i - 1]; + auto start_idx = array_offsets[start + i - 1]; if constexpr (compare_semantics && is_Decimal256) { auto * p = pos[i]; for (size_t j = 0; j < len; ++j) { - // Clear the value and only set the necessary parts for compare semantic + // Clear the value and only set the necessary parts for compare semantics memset(static_cast(&tmp_val), 0, sizeof(T)); - tmp_val.value.backend().assign(data[start + i].value.backend()); + tmp_val.value.backend().assign(data[start_idx + j].value.backend()); + tiflash_compiler_builtin_memcpy(p, &tmp_val, sizeof(T)); p += sizeof(T); } @@ -406,7 +413,6 @@ void ColumnDecimal::serializeToPosForColumnArrayImpl( } else { - auto start_idx = array_offsets[start + i - 1]; if (len <= 4) { auto * p = pos[i]; @@ -521,12 +527,6 @@ void ColumnDecimal::deserializeAndInsertFromPos( } } -template -void ColumnDecimal::deserializeForCmpAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) -{ - deserializeAndInsertFromPos(pos, use_nt_align_buffer); -} - template void ColumnDecimal::deserializeAndInsertFromPosForColumnArray( PaddedPODArray & pos, @@ -573,15 +573,6 @@ void ColumnDecimal::deserializeAndInsertFromPosForColumnArray( } } -template -void ColumnDecimal::deserializeForCmpAndInsertFromPosColumnArray( - PaddedPODArray & pos, - const IColumn::Offsets & array_offsets, - bool use_nt_align_buffer) -{ - deserializeAndInsertFromPosForColumnArray(pos, array_offsets, use_nt_align_buffer); -} - template void ColumnDecimal::flushNTAlignBuffer() { diff --git a/dbms/src/Columns/ColumnDecimal.h b/dbms/src/Columns/ColumnDecimal.h index b64ca3d30c4..9de8ce0aeb7 100644 --- a/dbms/src/Columns/ColumnDecimal.h +++ b/dbms/src/Columns/ColumnDecimal.h @@ -211,16 +211,11 @@ class ColumnDecimal final : public COWPtrHelper & pos, bool use_nt_align_buffer) override; - void deserializeForCmpAndInsertFromPos(PaddedPODArray & pos, bool 
use_nt_align_buffer) override; void deserializeAndInsertFromPosForColumnArray( PaddedPODArray & pos, const IColumn::Offsets & array_offsets, bool use_nt_align_buffer) override; - void deserializeForCmpAndInsertFromPosColumnArray( - PaddedPODArray & pos, - const IColumn::Offsets & array_offsets, - bool use_nt_align_buffer) override; void flushNTAlignBuffer() override; diff --git a/dbms/src/Columns/ColumnFixedString.h b/dbms/src/Columns/ColumnFixedString.h index a437c8c617b..aceb30467a9 100644 --- a/dbms/src/Columns/ColumnFixedString.h +++ b/dbms/src/Columns/ColumnFixedString.h @@ -177,22 +177,11 @@ class ColumnFixedString final : public COWPtrHelper String *) const override; void deserializeAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override; - void deserializeForCmpAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override - { - deserializeAndInsertFromPos(pos, use_nt_align_buffer); - } void deserializeAndInsertFromPosForColumnArray( PaddedPODArray & pos, const IColumn::Offsets & array_offsets, bool use_nt_align_buffer) override; - void deserializeForCmpAndInsertFromPosColumnArray( - PaddedPODArray & pos, - const IColumn::Offsets & array_offsets, - bool use_nt_align_buffer) override - { - deserializeAndInsertFromPosForColumnArray(pos, array_offsets, use_nt_align_buffer); - } void flushNTAlignBuffer() override {} diff --git a/dbms/src/Columns/ColumnFunction.h b/dbms/src/Columns/ColumnFunction.h index ef712883144..4a0b093ee13 100644 --- a/dbms/src/Columns/ColumnFunction.h +++ b/dbms/src/Columns/ColumnFunction.h @@ -205,12 +205,6 @@ class ColumnFunction final : public COWPtrHelper "Method deserializeAndInsertFromPos is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void deserializeForCmpAndInsertFromPos(PaddedPODArray & /* pos */, bool /* use_nt_align_buffer */) override - { - throw Exception( - "Method deserializeForCmpAndInsertFromPos is not supported for " + getName(), - ErrorCodes::NOT_IMPLEMENTED); - } void 
deserializeAndInsertFromPosForColumnArray( PaddedPODArray & /* pos */, @@ -221,15 +215,6 @@ class ColumnFunction final : public COWPtrHelper "Method deserializeAndInsertFromPosForColumnArray is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void deserializeForCmpAndInsertFromPosColumnArray( - PaddedPODArray & /* pos */, - const IColumn::Offsets & /* array_offsets */, - bool /* use_nt_align_buffer */) override - { - throw Exception( - "Method deserializeForCmpAndInsertFromPosColumnArray is not supported for " + getName(), - ErrorCodes::NOT_IMPLEMENTED); - } void flushNTAlignBuffer() override { diff --git a/dbms/src/Columns/ColumnNullable.cpp b/dbms/src/Columns/ColumnNullable.cpp index 38a23280f13..9a2b640c5da 100644 --- a/dbms/src/Columns/ColumnNullable.cpp +++ b/dbms/src/Columns/ColumnNullable.cpp @@ -282,6 +282,12 @@ const char * ColumnNullable::deserializeAndInsertFromArena(const char * pos, con return pos; } +void ColumnNullable::countSerializeByteSize(PaddedPODArray & byte_size) const +{ + getNullMapColumn().countSerializeByteSize(byte_size); + getNestedColumn().countSerializeByteSize(byte_size); +} + void ColumnNullable::countSerializeByteSizeForCmp( PaddedPODArray & byte_size, const NullMap * nullmap, @@ -292,10 +298,13 @@ void ColumnNullable::countSerializeByteSizeForCmp( getNullMapColumn().countSerializeByteSizeForCmp(byte_size, nullptr, collator); getNestedColumn().countSerializeByteSizeForCmp(byte_size, &getNullMapData(), collator); } -void ColumnNullable::countSerializeByteSize(PaddedPODArray & byte_size) const + +void ColumnNullable::countSerializeByteSizeForColumnArray( + PaddedPODArray & byte_size, + const IColumn::Offsets & array_offsets) const { - getNullMapColumn().countSerializeByteSize(byte_size); - getNestedColumn().countSerializeByteSize(byte_size); + getNullMapColumn().countSerializeByteSizeForColumnArray(byte_size, array_offsets); + getNestedColumn().countSerializeByteSizeForColumnArray(byte_size, array_offsets); } void 
ColumnNullable::countSerializeByteSizeForCmpColumnArray( @@ -306,15 +315,15 @@ void ColumnNullable::countSerializeByteSizeForCmpColumnArray( { // Unable to handle ColumnArray(ColumnNullable(ColumnXXX)). throw Exception( - "countSerializeByteSizeForCmpColumnArray cannot handle ColumnArray(" + getName() + ")", + "countSerializeByteSizeForCmpColumnArray cannot handle ColumnArray(ColumnNullable(ColumnXXX))" + getName() + + ")", ErrorCodes::NOT_IMPLEMENTED); } -void ColumnNullable::countSerializeByteSizeForColumnArray( - PaddedPODArray & byte_size, - const IColumn::Offsets & array_offsets) const + +void ColumnNullable::serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const { - getNullMapColumn().countSerializeByteSizeForColumnArray(byte_size, array_offsets); - getNestedColumn().countSerializeByteSizeForColumnArray(byte_size, array_offsets); + getNullMapColumn().serializeToPos(pos, start, length, has_null); + getNestedColumn().serializeToPos(pos, start, length, has_null); } void ColumnNullable::serializeToPosForCmp( @@ -333,10 +342,15 @@ void ColumnNullable::serializeToPosForCmp( .serializeToPosForCmp(pos, start, length, has_null, &getNullMapData(), collator, sort_key_container); } -void ColumnNullable::serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const +void ColumnNullable::serializeToPosForColumnArray( + PaddedPODArray & pos, + size_t start, + size_t length, + bool has_null, + const IColumn::Offsets & array_offsets) const { - getNullMapColumn().serializeToPos(pos, start, length, has_null); - getNestedColumn().serializeToPos(pos, start, length, has_null); + getNullMapColumn().serializeToPosForColumnArray(pos, start, length, has_null, array_offsets); + getNestedColumn().serializeToPosForColumnArray(pos, start, length, has_null, array_offsets); } void ColumnNullable::serializeToPosForCmpColumnArray( @@ -353,39 +367,16 @@ void ColumnNullable::serializeToPosForCmpColumnArray( // while 
ColumnNullable::nullmap corresponds to the rows of ColumnNullable. // This means it's not easy to correctly serialize the row in ColumnNullable to the corresponding position in pos. throw Exception( - "serializeToPosForCmpColumnArray cannot handle ColumnArray(" + getName() + ")", + "serializeToPosForCmpColumnArray cannot handle ColumnArray(ColumnNullable(ColumnXXX))" + getName() + ")", ErrorCodes::NOT_IMPLEMENTED); } -void ColumnNullable::serializeToPosForColumnArray( - PaddedPODArray & pos, - size_t start, - size_t length, - bool has_null, - const IColumn::Offsets & array_offsets) const -{ - getNullMapColumn().serializeToPosForColumnArray(pos, start, length, has_null, array_offsets); - getNestedColumn().serializeToPosForColumnArray(pos, start, length, has_null, array_offsets); -} -void ColumnNullable::deserializeForCmpAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) -{ - getNullMapColumn().deserializeForCmpAndInsertFromPos(pos, use_nt_align_buffer); - getNestedColumn().deserializeForCmpAndInsertFromPos(pos, use_nt_align_buffer); -} void ColumnNullable::deserializeAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) { getNullMapColumn().deserializeAndInsertFromPos(pos, use_nt_align_buffer); getNestedColumn().deserializeAndInsertFromPos(pos, use_nt_align_buffer); } -void ColumnNullable::deserializeForCmpAndInsertFromPosColumnArray( - PaddedPODArray & pos, - const IColumn::Offsets & array_offsets, - bool use_nt_align_buffer) -{ - getNullMapColumn().deserializeForCmpAndInsertFromPosColumnArray(pos, array_offsets, use_nt_align_buffer); - getNestedColumn().deserializeForCmpAndInsertFromPosColumnArray(pos, array_offsets, use_nt_align_buffer); -} void ColumnNullable::deserializeAndInsertFromPosForColumnArray( PaddedPODArray & pos, const IColumn::Offsets & array_offsets, diff --git a/dbms/src/Columns/ColumnNullable.h b/dbms/src/Columns/ColumnNullable.h index d173221f0f1..15f8a515ccd 100644 --- a/dbms/src/Columns/ColumnNullable.h +++ 
b/dbms/src/Columns/ColumnNullable.h @@ -122,16 +122,11 @@ class ColumnNullable final : public COWPtrHelper String * sort_key_container) const override; void deserializeAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override; - void deserializeForCmpAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override; void deserializeAndInsertFromPosForColumnArray( PaddedPODArray & pos, const IColumn::Offsets & array_offsets, bool use_nt_align_buffer) override; - void deserializeForCmpAndInsertFromPosColumnArray( - PaddedPODArray & pos, - const IColumn::Offsets & array_offsets, - bool use_nt_align_buffer) override; void flushNTAlignBuffer() override; diff --git a/dbms/src/Columns/ColumnString.cpp b/dbms/src/Columns/ColumnString.cpp index c7349db4e18..a7ba9d3fbcc 100644 --- a/dbms/src/Columns/ColumnString.cpp +++ b/dbms/src/Columns/ColumnString.cpp @@ -35,17 +35,6 @@ extern const int PARAMETER_OUT_OF_BOUND; extern const int SIZES_OF_COLUMNS_DOESNT_MATCH; } // namespace ErrorCodes -struct ColumnStringDefaultValue -{ - char mem[sizeof(UInt32) + 1] = {0}; - ColumnStringDefaultValue() - { - UInt32 str_size = 1; - tiflash_compiler_builtin_memcpy(&mem[0], &str_size, sizeof(str_size)); - mem[sizeof(UInt32)] = 0; - } -}; - MutableColumnPtr ColumnString::cloneResized(size_t to_size) const { auto res = ColumnString::create(); @@ -493,6 +482,11 @@ void ColumnString::getPermutationWithCollationImpl( } } +void ColumnString::countSerializeByteSize(PaddedPODArray & byte_size) const +{ + countSerializeByteSizeImpl(byte_size, nullptr, nullptr); +} + void ColumnString::countSerializeByteSizeForCmp( PaddedPODArray & byte_size, const NullMap * nullmap, @@ -529,11 +523,6 @@ void ColumnString::countSerializeByteSizeForCmp( } } -void ColumnString::countSerializeByteSize(PaddedPODArray & byte_size) const -{ - countSerializeByteSizeImpl(byte_size, nullptr, nullptr); -} - template void ColumnString::countSerializeByteSizeImpl( PaddedPODArray & byte_size, @@ -586,6 
+575,15 @@ void ColumnString::countSerializeByteSizeImpl( } } +void ColumnString::countSerializeByteSizeForColumnArray( + PaddedPODArray & byte_size, + const IColumn::Offsets & array_offsets) const +{ + countSerializeByteSizeForColumnArrayImpl< + /*need_decode_collator=*/false, + /*has_nullmap=*/false>(byte_size, array_offsets, nullptr, nullptr); +} + void ColumnString::countSerializeByteSizeForCmpColumnArray( PaddedPODArray & byte_size, const IColumn::Offsets & array_offsets, @@ -617,15 +615,6 @@ void ColumnString::countSerializeByteSizeForCmpColumnArray( } } -void ColumnString::countSerializeByteSizeForColumnArray( - PaddedPODArray & byte_size, - const IColumn::Offsets & array_offsets) const -{ - countSerializeByteSizeForColumnArrayImpl< - /*need_decode_collator=*/false, - /*has_nullmap=*/false>(byte_size, array_offsets, nullptr, nullptr); -} - template void ColumnString::countSerializeByteSizeForColumnArrayImpl( PaddedPODArray & byte_size, @@ -698,152 +687,35 @@ inline bool needDecodeCollatorForCmp(const TiDB::TiDBCollatorPtr & collator) return collator != nullptr && !collator->isTrivialCollator(); } -void ColumnString::serializeToPosForCmp( - PaddedPODArray & pos, - size_t start, - size_t length, - bool has_null, - const NullMap * nullmap, - const TiDB::TiDBCollatorPtr & collator, - String * sort_key_container) const -{ - if (has_null) - { - if (needDecodeCollatorForCmp(collator)) - { - if (nullmap != nullptr) - { - serializeToPosImplType( - pos, - start, - length, - collator, - sort_key_container, - nullmap); - } - else - { - serializeToPosImplType( - pos, - start, - length, - collator, - sort_key_container, - nullptr); - } - } - else - { - if (nullmap != nullptr) - { - serializeToPosImplType( - pos, - start, - length, - nullptr, - nullptr, - nullmap); - } - else - { - serializeToPosImplType( - pos, - start, - length, - nullptr, - nullptr, - nullptr); - } - } - } - else - { - if (needDecodeCollatorForCmp(collator)) - { - if (nullmap != nullptr) - { - 
serializeToPosImplType( - pos, - start, - length, - collator, - sort_key_container, - nullmap); - } - else - { - serializeToPosImplType( - pos, - start, - length, - collator, - sort_key_container, - nullptr); - } - } - else - { - if (nullmap != nullptr) - { - serializeToPosImplType( - pos, - start, - length, - nullptr, - nullptr, - nullmap); - } - else - { - serializeToPosImplType( - pos, - start, - length, - nullptr, - nullptr, - nullptr); - } - } - } -} - void ColumnString::serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const { if (has_null) - serializeToPosImplType( - pos, - start, - length, - nullptr, - nullptr, - nullptr); + serializeToPosImpl< + /*has_null=*/true, + /*need_decode_collator=*/false, + TiDB::ITiDBCollator, + /*has_nullmap=*/false>(pos, start, length, nullptr, nullptr, nullptr); else - serializeToPosImplType( - pos, - start, - length, - nullptr, - nullptr, - nullptr); + serializeToPosImpl< + /*has_null=*/false, + /*need_decode_collator=*/false, + TiDB::ITiDBCollator, + /*has_nullmap=*/false>(pos, start, length, nullptr, nullptr, nullptr); } -template -void ColumnString::serializeToPosImplType( +void ColumnString::serializeToPosForCmp( PaddedPODArray & pos, size_t start, size_t length, + bool has_null, + const NullMap * nullmap, const TiDB::TiDBCollatorPtr & collator, - String * sort_key_container, - const NullMap * nullmap) const + String * sort_key_container) const { - if constexpr (need_decode_collator) - { - RUNTIME_CHECK(collator && sort_key_container); - -#define M(VAR_PREFIX, COLLATOR_NAME, IMPL_TYPE, COLLATOR_ID) \ +#define M(VAR_PREFIX, COLLATOR_NAME, IMPL_TYPE, COLLATOR_ID, has_null, has_nullmap) \ case (COLLATOR_ID): \ { \ - serializeToPosImpl( \ + serializeToPosImpl( \ pos, \ start, \ length, \ @@ -853,26 +725,50 @@ void ColumnString::serializeToPosImplType( break; \ } - switch (collator->getCollatorId()) - { - APPLY_FOR_COLLATOR_TYPES(M) - default: - { - throw Exception(fmt::format("unexpected 
collator: {}", collator->getCollatorId())); - } - }; -#undef M +#define CALL3(has_null, has_nullmap) \ + { \ + RUNTIME_CHECK(collator && sort_key_container); \ + switch (collator->getCollatorId()) \ + { \ + APPLY_FOR_COLLATOR_TYPES(M, has_null, has_nullmap) \ + default: \ + { \ + throw Exception(fmt::format("unexpected collator: {}", collator->getCollatorId())); \ + } \ + } \ } - else - { - serializeToPosImpl( - pos, - start, - length, - collator, - sort_key_container, - nullmap); + +#define CALL2(has_null, has_nullmap) \ + { \ + if (needDecodeCollatorForCmp(collator)) \ + CALL3(has_null, has_nullmap) \ + else \ + serializeToPosImpl( \ + pos, \ + start, \ + length, \ + collator, \ + sort_key_container, \ + nullmap); \ + } + +#define CALL1(has_null) \ + { \ + if (nullmap) \ + CALL2(has_null, true) \ + else \ + CALL2(has_null, false) \ } + + if (has_null) + CALL1(true) + else + CALL1(false) + +#undef CALL1 +#undef CALL2 +#undef CALL3 +#undef M } template @@ -890,8 +786,16 @@ void ColumnString::serializeToPosImpl( RUNTIME_CHECK(!has_nullmap || (nullmap && nullmap->size() == size())); /// To avoid virtual function call of sortKey(). 
- static const ColumnStringDefaultValue col_str_def_val; const auto * derived_collator = static_cast(collator); + + static const std::array default_val = [] { + std::array val{}; + UInt32 sz = 1; + tiflash_compiler_builtin_memcpy(val.data(), &sz, sizeof(UInt32)); + val[sizeof(UInt32)] = 0; + return val; + }(); + /// countSerializeByteSizeImpl has already checked that the size of one element is not greater than UINT32_MAX for (size_t i = 0; i < length; ++i) { @@ -905,8 +809,8 @@ void ColumnString::serializeToPosImpl( { if (DB::isNullAt(*nullmap, start + i)) { - tiflash_compiler_builtin_memcpy(pos[i], &col_str_def_val.mem[0], sizeof(col_str_def_val.mem)); - pos[i] += sizeof(col_str_def_val.mem); + tiflash_compiler_builtin_memcpy(pos[i], default_val.data(), sizeof(default_val)); + pos[i] += sizeof(default_val); continue; } } @@ -937,100 +841,6 @@ void ColumnString::serializeToPosImpl( } } -void ColumnString::serializeToPosForCmpColumnArray( - PaddedPODArray & pos, - size_t start, - size_t length, - bool has_null, - const NullMap * nullmap, - const IColumn::Offsets & array_offsets, - const TiDB::TiDBCollatorPtr & collator, - String * sort_key_container) const -{ - if (has_null) - { - if (needDecodeCollatorForCmp(collator)) - { - if (nullmap != nullptr) - { - serializeToPosForColumnArrayImplType< - /*compare_semantics=*/true, - /*has_null=*/true, - /*need_decode_collator=*/true, - /*has_nullmap=*/true>(pos, start, length, array_offsets, collator, sort_key_container, nullmap); - } - else - { - serializeToPosForColumnArrayImplType< - /*compare_semantics=*/true, - /*has_null=*/true, - /*need_decode_collator=*/true, - /*has_nullmap=*/false>(pos, start, length, array_offsets, collator, sort_key_container, nullptr); - } - } - else - { - if (nullmap != nullptr) - { - serializeToPosForColumnArrayImplType< - /*compare_semantics=*/true, - /*has_null=*/true, - /*need_decode_collator=*/false, - /*has_nullmap=*/true>(pos, start, length, array_offsets, nullptr, nullptr, nullmap); - } 
- else - { - serializeToPosForColumnArrayImplType< - /*compare_semantics=*/true, - /*has_null=*/true, - /*need_decode_collator=*/false, - /*has_nullmap=*/false>(pos, start, length, array_offsets, nullptr, nullptr, nullptr); - } - } - } - else - { - if (needDecodeCollatorForCmp(collator)) - { - if (nullmap != nullptr) - { - serializeToPosForColumnArrayImplType< - /*compare_semantics=*/true, - /*has_null=*/false, - /*need_decode_collator=*/true, - /*has_nullmap=*/true>(pos, start, length, array_offsets, collator, sort_key_container, nullmap); - } - else - { - serializeToPosForColumnArrayImplType< - /*compare_semantics=*/true, - /*has_null=*/false, - /*need_decode_collator=*/true, - /*has_nullmap=*/false>(pos, start, length, array_offsets, collator, sort_key_container, nullptr); - } - } - else - { - if (nullmap != nullptr) - { - serializeToPosForColumnArrayImplType< - /*compare_semantics=*/true, - /*has_null=*/false, - /*need_decode_collator=*/false, - /*has_nullmap=*/true>(pos, start, length, array_offsets, nullptr, nullptr, nullmap); - } - else - { - serializeToPosForColumnArrayImplType< - /*compare_semantics=*/true, - /*has_null=*/false, - /*need_decode_collator=*/false, - /*has_nullmap=*/false>(pos, start, length, array_offsets, nullptr, nullptr, nullptr); - } - } - } -} - void ColumnString::serializeToPosForColumnArray( PaddedPODArray & pos, size_t start, @@ -1039,69 +849,91 @@ void ColumnString::serializeToPosForColumnArray( const IColumn::Offsets & array_offsets) const { if (has_null) - serializeToPosForColumnArrayImplType< - /*compare_semantics=*/false, + serializeToPosForColumnArrayImpl< /*has_null=*/true, /*need_decode_collator=*/false, + TiDB::ITiDBCollator, /*has_nullmap=*/false>(pos, start, length, array_offsets, nullptr, nullptr, nullptr); else - serializeToPosForColumnArrayImplType< - /*compare_semantics=*/false, + serializeToPosForColumnArrayImpl< /*has_null=*/false, /*need_decode_collator=*/false, + TiDB::ITiDBCollator, /*has_nullmap=*/false>(pos, 
start, length, array_offsets, nullptr, nullptr, nullptr); } -template -void ColumnString::serializeToPosForColumnArrayImplType( +void ColumnString::serializeToPosForCmpColumnArray( PaddedPODArray & pos, size_t start, size_t length, + bool has_null, + const NullMap * nullmap, const IColumn::Offsets & array_offsets, const TiDB::TiDBCollatorPtr & collator, - String * sort_key_container, - const NullMap * nullmap) const + String * sort_key_container) const { - if constexpr (need_decode_collator) - { - RUNTIME_CHECK(collator && sort_key_container); - -#define M(VAR_PREFIX, COLLATOR_NAME, IMPL_TYPE, COLLATOR_ID) \ - case (COLLATOR_ID): \ - { \ - serializeToPosForColumnArrayImpl( \ - pos, \ - start, \ - length, \ - array_offsets, \ - collator, \ - sort_key_container, \ - nullmap); \ - break; \ +#define M(VAR_PREFIX, COLLATOR_NAME, IMPL_TYPE, COLLATOR_ID, has_null, has_nullmap) \ + case (COLLATOR_ID): \ + { \ + serializeToPosForColumnArrayImpl( \ + pos, \ + start, \ + length, \ + array_offsets, \ + collator, \ + sort_key_container, \ + nullmap); \ + break; \ } - switch (collator->getCollatorId()) - { - APPLY_FOR_COLLATOR_TYPES(M) - default: - { - throw Exception(fmt::format("unexpected collator: {}", collator->getCollatorId())); - } - }; -#undef M +#define CALL3(has_null, has_nullmap) \ + { \ + RUNTIME_CHECK(collator && sort_key_container); \ + switch (collator->getCollatorId()) \ + { \ + APPLY_FOR_COLLATOR_TYPES(M, has_null, has_nullmap) \ + default: \ + { \ + throw Exception(fmt::format("unexpected collator: {}", collator->getCollatorId())); \ + } \ + } \ } - else - { - serializeToPosForColumnArrayImpl< - compare_semantics, - has_null, - need_decode_collator, - TiDB::ITiDBCollator, - has_nullmap>(pos, start, length, array_offsets, collator, sort_key_container, nullmap); + +#define CALL2(has_null, has_nullmap) \ + { \ + if (needDecodeCollatorForCmp(collator)) \ + CALL3(has_null, has_nullmap) \ + else \ + serializeToPosForColumnArrayImpl( \ + pos, \ + start, \ + length, \ 
+ array_offsets, \ + collator, \ + sort_key_container, \ + nullmap); \ + } + +#define CALL1(has_null) \ + { \ + if (nullmap) \ + CALL2(has_null, true) \ + else \ + CALL2(has_null, false) \ } + + if (has_null) + CALL1(true) + else + CALL1(false) + +#undef CALL1 +#undef CALL2 +#undef CALL3 +#undef M } -template +template void ColumnString::serializeToPosForColumnArrayImpl( PaddedPODArray & pos, size_t start, @@ -1142,7 +974,7 @@ void ColumnString::serializeToPosForColumnArrayImpl( if (DB::isNullAt(*nullmap, start + i)) continue; } - if constexpr (compare_semantics && need_decode_collator) + if constexpr (need_decode_collator) { auto * size_pos = pos[i]; auto * p = pos[i]; @@ -1289,11 +1121,6 @@ void ColumnString::deserializeAndInsertFromPos(PaddedPODArray & pos, boo } } -void ColumnString::deserializeForCmpAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) -{ - deserializeAndInsertFromPos(pos, use_nt_align_buffer); -} - void ColumnString::deserializeAndInsertFromPosForColumnArray( PaddedPODArray & pos, const IColumn::Offsets & array_offsets, @@ -1335,14 +1162,6 @@ void ColumnString::deserializeAndInsertFromPosForColumnArray( } } -void ColumnString::deserializeForCmpAndInsertFromPosColumnArray( - PaddedPODArray & pos, - const IColumn::Offsets & array_offsets, - bool use_nt_align_buffer) -{ - deserializeAndInsertFromPosForColumnArrayImpl(pos, array_offsets, use_nt_align_buffer); -} - void ColumnString::flushNTAlignBuffer() { #ifdef TIFLASH_ENABLE_AVX_SUPPORT diff --git a/dbms/src/Columns/ColumnString.h b/dbms/src/Columns/ColumnString.h index 55c160236a8..f326dce8672 100644 --- a/dbms/src/Columns/ColumnString.h +++ b/dbms/src/Columns/ColumnString.h @@ -107,66 +107,6 @@ class ColumnString final : public COWPtrHelper } } - template - void countSerializeByteSizeImpl( - PaddedPODArray & byte_size, - const NullMap * nullmap, - const TiDB::TiDBCollatorPtr & collator) const; - template - void countSerializeByteSizeForColumnArrayImpl( - PaddedPODArray & byte_size, 
- const IColumn::Offsets & array_offsets, - const NullMap * nullmap, - const TiDB::TiDBCollatorPtr & collator) const; - - template - void serializeToPosImplType( - PaddedPODArray & pos, - size_t start, - size_t length, - const TiDB::TiDBCollatorPtr & collator, - String * sort_key_container, - const NullMap * nullmap) const; - template - void serializeToPosImpl( - PaddedPODArray & pos, - size_t start, - size_t length, - const TiDB::TiDBCollatorPtr & collator, - String * sort_key_container, - const NullMap * nullmap) const; - - template - void serializeToPosForColumnArrayImplType( - PaddedPODArray & pos, - size_t start, - size_t length, - const IColumn::Offsets & array_offsets, - const TiDB::TiDBCollatorPtr & collator, - String * sort_key_container, - const NullMap * nullmap) const; - template < - bool compare_semantics, - bool has_null, - bool need_decode_collator, - typename DerivedCollator, - bool has_nullmap> - void serializeToPosForColumnArrayImpl( - PaddedPODArray & pos, - size_t start, - size_t length, - const IColumn::Offsets & array_offsets, - const TiDB::TiDBCollatorPtr & collator, - String * sort_key_container, - const NullMap * nullmap) const; - - void deserializeAndInsertFromPosImpl(PaddedPODArray & pos, bool use_nt_align_buffer); - template - void deserializeAndInsertFromPosForColumnArrayImpl( - PaddedPODArray & pos, - const IColumn::Offsets & array_offsets, - bool use_nt_align_buffer); - public: const char * getFamilyName() const override { return "String"; } @@ -317,6 +257,11 @@ class ColumnString final : public COWPtrHelper PaddedPODArray & byte_size, const NullMap * nullmap, const TiDB::TiDBCollatorPtr & collator) const override; + template + void countSerializeByteSizeImpl( + PaddedPODArray & byte_size, + const NullMap * nullmap, + const TiDB::TiDBCollatorPtr & collator) const; void countSerializeByteSizeForColumnArray( PaddedPODArray & byte_size, @@ -326,6 +271,12 @@ class ColumnString final : public COWPtrHelper const IColumn::Offsets & 
array_offsets, const NullMap * nullmap, const TiDB::TiDBCollatorPtr & collator) const override; + template + void countSerializeByteSizeForColumnArrayImpl( + PaddedPODArray & byte_size, + const IColumn::Offsets & array_offsets, + const NullMap * nullmap, + const TiDB::TiDBCollatorPtr & collator) const; void serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const override; void serializeToPosForCmp( @@ -336,6 +287,14 @@ class ColumnString final : public COWPtrHelper const NullMap * nullmap, const TiDB::TiDBCollatorPtr & collator, String * sort_key_container) const override; + template + void serializeToPosImpl( + PaddedPODArray & pos, + size_t start, + size_t length, + const TiDB::TiDBCollatorPtr & collator, + String * sort_key_container, + const NullMap * nullmap) const; void serializeToPosForColumnArray( PaddedPODArray & pos, @@ -352,18 +311,22 @@ class ColumnString final : public COWPtrHelper const IColumn::Offsets & array_offsets, const TiDB::TiDBCollatorPtr & collator, String * sort_key_container) const override; + template + void serializeToPosForColumnArrayImpl( + PaddedPODArray & pos, + size_t start, + size_t length, + const IColumn::Offsets & array_offsets, + const TiDB::TiDBCollatorPtr & collator, + String * sort_key_container, + const NullMap * nullmap) const; void deserializeAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override; - void deserializeForCmpAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override; void deserializeAndInsertFromPosForColumnArray( PaddedPODArray & pos, const IColumn::Offsets & array_offsets, bool use_nt_align_buffer) override; - void deserializeForCmpAndInsertFromPosColumnArray( - PaddedPODArray & pos, - const IColumn::Offsets & array_offsets, - bool use_nt_align_buffer) override; void flushNTAlignBuffer() override; diff --git a/dbms/src/Columns/ColumnTuple.h b/dbms/src/Columns/ColumnTuple.h index dac317a4f0d..cf993910a14 100644 --- 
a/dbms/src/Columns/ColumnTuple.h +++ b/dbms/src/Columns/ColumnTuple.h @@ -183,11 +183,6 @@ class ColumnTuple final : public COWPtrHelper for (auto & column : columns) column->assumeMutableRef().deserializeAndInsertFromPos(pos, use_nt_align_buffer); } - void deserializeForCmpAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override - { - for (auto & column : columns) - column->assumeMutableRef().deserializeForCmpAndInsertFromPos(pos, use_nt_align_buffer); - } void deserializeAndInsertFromPosForColumnArray( PaddedPODArray & pos, @@ -200,17 +195,6 @@ class ColumnTuple final : public COWPtrHelper array_offsets, use_nt_align_buffer); } - void deserializeForCmpAndInsertFromPosColumnArray( - PaddedPODArray & pos, - const IColumn::Offsets & array_offsets, - bool use_nt_align_buffer) override - { - for (auto & column : columns) - column->assumeMutableRef().deserializeForCmpAndInsertFromPosColumnArray( - pos, - array_offsets, - use_nt_align_buffer); - } void flushNTAlignBuffer() override { diff --git a/dbms/src/Columns/ColumnVector.h b/dbms/src/Columns/ColumnVector.h index 096ad6b9dbf..17f6de6747d 100644 --- a/dbms/src/Columns/ColumnVector.h +++ b/dbms/src/Columns/ColumnVector.h @@ -380,22 +380,11 @@ class ColumnVector final : public COWPtrHelper & pos, bool use_nt_align_buffer) override; - void deserializeForCmpAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override - { - deserializeAndInsertFromPos(pos, use_nt_align_buffer); - } void deserializeAndInsertFromPosForColumnArray( PaddedPODArray & pos, const IColumn::Offsets & array_offsets, bool use_nt_align_buffer) override; - void deserializeForCmpAndInsertFromPosColumnArray( - PaddedPODArray & pos, - const IColumn::Offsets & array_offsets, - bool use_nt_align_buffer) override - { - deserializeAndInsertFromPosForColumnArray(pos, array_offsets, use_nt_align_buffer); - } void flushNTAlignBuffer() override; diff --git a/dbms/src/Columns/IColumn.h b/dbms/src/Columns/IColumn.h index 
13450adadee..1f77511bbcc 100644 --- a/dbms/src/Columns/IColumn.h +++ b/dbms/src/Columns/IColumn.h @@ -341,8 +341,6 @@ class IColumn : public COWPtr /// for (auto & column_ptr : mutable_columns) /// column_ptr->flushNTAlignBuffer(); virtual void deserializeAndInsertFromPos(PaddedPODArray & /* pos */, bool /* use_nt_align_buffer */) = 0; - virtual void deserializeForCmpAndInsertFromPos(PaddedPODArray & /* pos */, bool /* use_nt_align_buffer */) - = 0; /// Deserialize and insert data from pos and forward each pos[i] to the end of serialized data. /// Only called by ColumnArray. @@ -353,11 +351,6 @@ class IColumn : public COWPtr const Offsets & /* array_offsets */, bool /* use_nt_align_buffer */) = 0; - virtual void deserializeForCmpAndInsertFromPosColumnArray( - PaddedPODArray & /* pos */, - const Offsets & /* array_offsets */, - bool /* use_nt_align_buffer */) - = 0; virtual void flushNTAlignBuffer() = 0; diff --git a/dbms/src/Columns/IColumnDummy.h b/dbms/src/Columns/IColumnDummy.h index 9ead751cb2e..ac45ff60580 100644 --- a/dbms/src/Columns/IColumnDummy.h +++ b/dbms/src/Columns/IColumnDummy.h @@ -173,12 +173,6 @@ class IColumnDummy : public IColumn "Method deserializeAndInsertFromPos is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void deserializeForCmpAndInsertFromPos(PaddedPODArray & /* pos */, bool /* use_nt_align_buffer */) override - { - throw Exception( - "Method deserializeForCmpAndInsertFromPos is not supported for " + getName(), - ErrorCodes::NOT_IMPLEMENTED); - } void deserializeAndInsertFromPosForColumnArray( PaddedPODArray & /* pos */, @@ -189,15 +183,6 @@ class IColumnDummy : public IColumn "Method deserializeAndInsertFromPosForColumnArray is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void deserializeForCmpAndInsertFromPosColumnArray( - PaddedPODArray & /* pos */, - const IColumn::Offsets & /* array_offsets */, - bool /* use_nt_align_buffer */) override - { - throw Exception( - "Method 
deserializeForCmpAndInsertFromPosColumnArray is not supported for " + getName(), - ErrorCodes::NOT_IMPLEMENTED); - } void flushNTAlignBuffer() override { diff --git a/dbms/src/Columns/tests/gtest_column_serialize_deserialize.cpp b/dbms/src/Columns/tests/gtest_column_serialize_deserialize.cpp index e3009589752..fe5f652328e 100644 --- a/dbms/src/Columns/tests/gtest_column_serialize_deserialize.cpp +++ b/dbms/src/Columns/tests/gtest_column_serialize_deserialize.cpp @@ -393,7 +393,7 @@ class TestColumnSerializeDeserialize : public ::testing::Test auto new_col_ptr = column_ptr->cloneEmpty(); if (use_nt_align_buffer) new_col_ptr->reserveAlign(byte_size.size(), FULL_VECTOR_SIZE_AVX2); - new_col_ptr->deserializeForCmpAndInsertFromPos(ori_pos, use_nt_align_buffer); + new_col_ptr->deserializeAndInsertFromPos(ori_pos, use_nt_align_buffer); current_size = 0; pos.clear(); @@ -414,7 +414,7 @@ class TestColumnSerializeDeserialize : public ::testing::Test collator, sort_key_container); - new_col_ptr->deserializeForCmpAndInsertFromPos(ori_pos, use_nt_align_buffer); + new_col_ptr->deserializeAndInsertFromPos(ori_pos, use_nt_align_buffer); current_size = 0; pos.clear(); @@ -428,7 +428,7 @@ class TestColumnSerializeDeserialize : public ::testing::Test ori_pos.push_back(ptr); column_ptr->serializeToPosForCmp(pos, 0, byte_size.size(), false, nullptr, collator, sort_key_container); - new_col_ptr->deserializeForCmpAndInsertFromPos(ori_pos, use_nt_align_buffer); + new_col_ptr->deserializeAndInsertFromPos(ori_pos, use_nt_align_buffer); if (use_nt_align_buffer) new_col_ptr->flushNTAlignBuffer(); @@ -642,18 +642,7 @@ try .column; testCountSerializeByteSize(col_nullable_decimal_0, {49, 49, 49, 49, 49}); testSerializeAndDeserialize(col_nullable_decimal_0); - // nullable + bool + size_t + n * 8 - testCountSerializeByteSize( - col_nullable_decimal_0, - { - 1 + 1 + 8 + 1 * 8, - 1 + 1 + 8 + 2 * 8, - 1 + 1 + 8 + 1 * 8, - 1 + 1 + 8 + 1 * 8, - 1 + 1 + 8 + 2 * 8, - }, - true, - nullptr); + 
testCountSerializeByteSize(col_nullable_decimal_0, {49, 49, 49, 49, 49}, true, nullptr); testSerializeAndDeserialize(col_nullable_decimal_0, true, nullptr, nullptr); } @@ -742,13 +731,7 @@ try auto col_nullable_array_dec = ColumnNullable::create(col_array_dec, createColumn({1, 0, 1}).column); testCountSerializeByteSize(col_nullable_array_dec, {1 + 4 + 48, 1 + 4 + 48 * 2, 1 + 4 + 48 * 3}); testSerializeAndDeserialize(col_nullable_array_dec); - // 100.1111111111: (1 + 8 + 2 * 8) - // -11111111111111111111: (1 + 8 + 3 * 8) - testCountSerializeByteSize( - col_nullable_array_dec, - {1 + 4, 1 + 4 + (1 + 8 + 2 * 8) + (1 + 8 + 3 * 8), 1 + 4}, - true, - nullptr); + testCountSerializeByteSize(col_nullable_array_dec, {1 + 4, 1 + 4 + 48 * 2, 1 + 4}, true, nullptr); testSerializeAndDeserialize(col_nullable_array_dec, true, nullptr, nullptr); } diff --git a/dbms/src/Interpreters/Aggregator.h b/dbms/src/Interpreters/Aggregator.h index 6c47dafd8b0..e9480c0e895 100644 --- a/dbms/src/Interpreters/Aggregator.h +++ b/dbms/src/Interpreters/Aggregator.h @@ -428,7 +428,7 @@ struct AggregationMethodSerialized static void insertKeyIntoColumnsBatch(PaddedPODArray & key_places, std::vector & key_columns) { for (auto * key_column : key_columns) - key_column->deserializeForCmpAndInsertFromPos(key_places, false); + key_column->deserializeAndInsertFromPos(key_places, false); } }; diff --git a/dbms/src/TiDB/Collation/Collator.h b/dbms/src/TiDB/Collation/Collator.h index 313d7152472..4f3efb5a86e 100644 --- a/dbms/src/TiDB/Collation/Collator.h +++ b/dbms/src/TiDB/Collation/Collator.h @@ -413,17 +413,25 @@ using BIN_COLLATOR_PADDING = BinCollator; using BIN_COLLATOR_NON_PADDING = BinCollator; } // namespace TiDB -#define APPLY_FOR_COLLATOR_TYPES_WITH_VARS(VAR_PREFIX, M) \ - M(VAR_PREFIX, utf8_general_ci, TiDB::GeneralCICollator, TiDB::ITiDBCollator::UTF8_GENERAL_CI) \ - M(VAR_PREFIX, utf8mb4_general_ci, TiDB::GeneralCICollator, TiDB::ITiDBCollator::UTF8MB4_GENERAL_CI) \ - M(VAR_PREFIX, 
utf8_unicode_ci, TiDB::UCACI_0400_PADDING, TiDB::ITiDBCollator::UTF8_UNICODE_CI) \ - M(VAR_PREFIX, utf8mb4_unicode_ci, TiDB::UCACI_0400_PADDING, TiDB::ITiDBCollator::UTF8MB4_UNICODE_CI) \ - M(VAR_PREFIX, utf8mb4_0900_ai_ci, TiDB::UCACI_0900_NON_PADDING, TiDB::ITiDBCollator::UTF8MB4_0900_AI_CI) \ - M(VAR_PREFIX, utf8mb4_0900_bin, TiDB::UTF8MB4_0900_BIN_TYPE, TiDB::ITiDBCollator::UTF8MB4_0900_BIN) \ - M(VAR_PREFIX, utf8mb4_bin, TiDB::UTF8MB4_BIN_TYPE, TiDB::ITiDBCollator::UTF8MB4_BIN) \ - M(VAR_PREFIX, latin1_bin, TiDB::BIN_COLLATOR_PADDING, TiDB::ITiDBCollator::LATIN1_BIN) \ - M(VAR_PREFIX, binary, TiDB::BIN_COLLATOR_NON_PADDING, TiDB::ITiDBCollator::BINARY) \ - M(VAR_PREFIX, ascii_bin, TiDB::BIN_COLLATOR_PADDING, TiDB::ITiDBCollator::ASCII_BIN) \ - M(VAR_PREFIX, utf8_bin, TiDB::UTF8MB4_BIN_TYPE, TiDB::ITiDBCollator::UTF8_BIN) - -#define APPLY_FOR_COLLATOR_TYPES(M) APPLY_FOR_COLLATOR_TYPES_WITH_VARS(tmp_, M) +#define APPLY_FOR_COLLATOR_TYPES_WITH_VARS(VAR_PREFIX, M, ...) \ + M(VAR_PREFIX, utf8_general_ci, TiDB::GeneralCICollator, TiDB::ITiDBCollator::UTF8_GENERAL_CI, ##__VA_ARGS__) \ + M(VAR_PREFIX, utf8mb4_general_ci, TiDB::GeneralCICollator, TiDB::ITiDBCollator::UTF8MB4_GENERAL_CI, ##__VA_ARGS__) \ + M(VAR_PREFIX, utf8_unicode_ci, TiDB::UCACI_0400_PADDING, TiDB::ITiDBCollator::UTF8_UNICODE_CI, ##__VA_ARGS__) \ + M(VAR_PREFIX, \ + utf8mb4_unicode_ci, \ + TiDB::UCACI_0400_PADDING, \ + TiDB::ITiDBCollator::UTF8MB4_UNICODE_CI, \ + ##__VA_ARGS__) \ + M(VAR_PREFIX, \ + utf8mb4_0900_ai_ci, \ + TiDB::UCACI_0900_NON_PADDING, \ + TiDB::ITiDBCollator::UTF8MB4_0900_AI_CI, \ + ##__VA_ARGS__) \ + M(VAR_PREFIX, utf8mb4_0900_bin, TiDB::UTF8MB4_0900_BIN_TYPE, TiDB::ITiDBCollator::UTF8MB4_0900_BIN, ##__VA_ARGS__) \ + M(VAR_PREFIX, utf8mb4_bin, TiDB::UTF8MB4_BIN_TYPE, TiDB::ITiDBCollator::UTF8MB4_BIN, ##__VA_ARGS__) \ + M(VAR_PREFIX, latin1_bin, TiDB::BIN_COLLATOR_PADDING, TiDB::ITiDBCollator::LATIN1_BIN, ##__VA_ARGS__) \ + M(VAR_PREFIX, binary, TiDB::BIN_COLLATOR_NON_PADDING, 
TiDB::ITiDBCollator::BINARY, ##__VA_ARGS__) \ + M(VAR_PREFIX, ascii_bin, TiDB::BIN_COLLATOR_PADDING, TiDB::ITiDBCollator::ASCII_BIN, ##__VA_ARGS__) \ + M(VAR_PREFIX, utf8_bin, TiDB::UTF8MB4_BIN_TYPE, TiDB::ITiDBCollator::UTF8_BIN, ##__VA_ARGS__) + +#define APPLY_FOR_COLLATOR_TYPES(M, ...) APPLY_FOR_COLLATOR_TYPES_WITH_VARS(tmp_, M, ##__VA_ARGS__) From 10b7e51d69283c0ea16056300d683ad082d0aa9e Mon Sep 17 00:00:00 2001 From: gengliqi Date: Wed, 26 Mar 2025 00:58:11 +0800 Subject: [PATCH 10/84] u Signed-off-by: gengliqi --- dbms/src/Columns/ColumnArray.h | 31 +++++------ dbms/src/Columns/ColumnDecimal.cpp | 4 +- dbms/src/Columns/ColumnFixedString.cpp | 71 +++++++++++++++----------- dbms/src/Columns/ColumnFixedString.h | 50 ++++++------------ dbms/src/Columns/ColumnVector.cpp | 4 +- dbms/src/Columns/IColumn.h | 1 + 6 files changed, 77 insertions(+), 84 deletions(-) diff --git a/dbms/src/Columns/ColumnArray.h b/dbms/src/Columns/ColumnArray.h index b9f50ba62fe..b5b2c6ee9b2 100644 --- a/dbms/src/Columns/ColumnArray.h +++ b/dbms/src/Columns/ColumnArray.h @@ -44,24 +44,6 @@ class ColumnArray final : public COWPtrHelper ColumnArray(const ColumnArray &) = default; - template - void countSerializeByteSizeImpl( - PaddedPODArray & byte_size, - const NullMap * nullmap, - const TiDB::TiDBCollatorPtr & collator) const; - - template - void serializeToPosImpl( - PaddedPODArray & pos, - size_t start, - size_t length, - const TiDB::TiDBCollatorPtr & collator, - String * sort_key_container, - const NullMap * nullmap) const; - - template - void deserializeAndInsertFromPosImpl(PaddedPODArray & pos, bool use_nt_align_buffer); - public: /** Create immutable column using immutable arguments. This arguments may be shared with other columns. * Use IColumn::mutate in order to make mutable column and mutate shared nested columns. 
@@ -105,6 +87,11 @@ class ColumnArray final : public COWPtrHelper PaddedPODArray & byte_size, const NullMap * nullmap, const TiDB::TiDBCollatorPtr & collator) const override; + template + void countSerializeByteSizeImpl( + PaddedPODArray & byte_size, + const NullMap * nullmap, + const TiDB::TiDBCollatorPtr & collator) const; void countSerializeByteSizeForColumnArray( PaddedPODArray & /* byte_size */, @@ -134,6 +121,14 @@ class ColumnArray final : public COWPtrHelper const NullMap * nullmap, const TiDB::TiDBCollatorPtr & collator, String * sort_key_container) const override; + template + void serializeToPosImpl( + PaddedPODArray & pos, + size_t start, + size_t length, + const TiDB::TiDBCollatorPtr & collator, + String * sort_key_container, + const NullMap * nullmap) const; void serializeToPosForColumnArray( PaddedPODArray & /* pos */, diff --git a/dbms/src/Columns/ColumnDecimal.cpp b/dbms/src/Columns/ColumnDecimal.cpp index 4067e3a049d..356c75937fa 100644 --- a/dbms/src/Columns/ColumnDecimal.cpp +++ b/dbms/src/Columns/ColumnDecimal.cpp @@ -396,7 +396,7 @@ void ColumnDecimal::serializeToPosForColumnArrayImpl( } size_t len = array_offsets[start + i] - array_offsets[start + i - 1]; - auto start_idx = array_offsets[start + i - 1]; + size_t start_idx = array_offsets[start + i - 1]; if constexpr (compare_semantics && is_Decimal256) { auto * p = pos[i]; @@ -554,7 +554,7 @@ void ColumnDecimal::deserializeAndInsertFromPosForColumnArray( for (size_t i = 0; i < size; ++i) { size_t len = array_offsets[start_point + i] - array_offsets[start_point + i - 1]; - auto start_idx = array_offsets[start_point + i - 1]; + size_t start_idx = array_offsets[start_point + i - 1]; if (len <= 4) { auto * p = pos[i]; diff --git a/dbms/src/Columns/ColumnFixedString.cpp b/dbms/src/Columns/ColumnFixedString.cpp index 760e5aa6d26..949ad243785 100644 --- a/dbms/src/Columns/ColumnFixedString.cpp +++ b/dbms/src/Columns/ColumnFixedString.cpp @@ -138,7 +138,7 @@ const char * 
ColumnFixedString::deserializeAndInsertFromArena(const char * pos, return pos + n; } -void ColumnFixedString::countSerializeByteSizeImpl(PaddedPODArray & byte_size) const +void ColumnFixedString::countSerializeByteSize(PaddedPODArray & byte_size) const { RUNTIME_CHECK_MSG(byte_size.size() == size(), "size of byte_size({}) != column size({})", byte_size.size(), size()); @@ -147,6 +147,26 @@ void ColumnFixedString::countSerializeByteSizeImpl(PaddedPODArray & byte byte_size[i] += n; } +void ColumnFixedString::countSerializeByteSizeForCmp( + PaddedPODArray & byte_size, + const NullMap * /*nullmap*/, + const TiDB::TiDBCollatorPtr & collator) const +{ + // collator->sortKey() will change the string length, which may exceeds n. + RUNTIME_CHECK_MSG( + !collator, + "{} doesn't support countSerializeByteSizeForCmp when collator is not null", + getName()); + countSerializeByteSize(byte_size); +} + +void ColumnFixedString::countSerializeByteSizeForColumnArray( + PaddedPODArray & byte_size, + const IColumn::Offsets & array_offsets) const +{ + countSerializeByteSizeForColumnArrayImpl(byte_size, array_offsets, nullptr); +} + void ColumnFixedString::countSerializeByteSizeForCmpColumnArray( PaddedPODArray & byte_size, const IColumn::Offsets & array_offsets, @@ -163,13 +183,6 @@ void ColumnFixedString::countSerializeByteSizeForCmpColumnArray( countSerializeByteSizeForColumnArrayImpl(byte_size, array_offsets, nullptr); } -void ColumnFixedString::countSerializeByteSizeForColumnArray( - PaddedPODArray & byte_size, - const IColumn::Offsets & array_offsets) const -{ - countSerializeByteSizeForColumnArrayImpl(byte_size, array_offsets, nullptr); -} - template void ColumnFixedString::countSerializeByteSizeForColumnArrayImpl( PaddedPODArray & byte_size, @@ -194,6 +207,14 @@ void ColumnFixedString::countSerializeByteSizeForColumnArrayImpl( } } +void ColumnFixedString::serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const +{ + if (has_null) + 
serializeToPosImpl(pos, start, length, nullptr); + else + serializeToPosImpl(pos, start, length, nullptr); +} + void ColumnFixedString::serializeToPosForCmp( PaddedPODArray & pos, size_t start, @@ -220,14 +241,6 @@ void ColumnFixedString::serializeToPosForCmp( } } -void ColumnFixedString::serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const -{ - if (has_null) - serializeToPosImpl(pos, start, length, nullptr); - else - serializeToPosImpl(pos, start, length, nullptr); -} - template void ColumnFixedString::serializeToPosImpl( PaddedPODArray & pos, @@ -261,6 +274,19 @@ void ColumnFixedString::serializeToPosImpl( } } +void ColumnFixedString::serializeToPosForColumnArray( + PaddedPODArray & pos, + size_t start, + size_t length, + bool has_null, + const IColumn::Offsets & array_offsets) const +{ + if (has_null) + serializeToPosForColumnArrayImpl(pos, start, length, array_offsets, nullptr); + else + serializeToPosForColumnArrayImpl(pos, start, length, array_offsets, nullptr); +} + void ColumnFixedString::serializeToPosForCmpColumnArray( PaddedPODArray & pos, size_t start, @@ -291,19 +317,6 @@ void ColumnFixedString::serializeToPosForCmpColumnArray( } } -void ColumnFixedString::serializeToPosForColumnArray( - PaddedPODArray & pos, - size_t start, - size_t length, - bool has_null, - const IColumn::Offsets & array_offsets) const -{ - if (has_null) - serializeToPosForColumnArrayImpl(pos, start, length, array_offsets, nullptr); - else - serializeToPosForColumnArrayImpl(pos, start, length, array_offsets, nullptr); -} - template void ColumnFixedString::serializeToPosForColumnArrayImpl( PaddedPODArray & pos, diff --git a/dbms/src/Columns/ColumnFixedString.h b/dbms/src/Columns/ColumnFixedString.h index aceb30467a9..736f3b79d2e 100644 --- a/dbms/src/Columns/ColumnFixedString.h +++ b/dbms/src/Columns/ColumnFixedString.h @@ -54,25 +54,6 @@ class ColumnFixedString final : public COWPtrHelper , chars(src.chars.begin(), src.chars.end()) , n(src.n){}; - 
void countSerializeByteSizeImpl(PaddedPODArray & byte_size) const; - - template - void countSerializeByteSizeForColumnArrayImpl( - PaddedPODArray & byte_size, - const IColumn::Offsets & array_offsets, - const NullMap * nullmap) const; - - template - void serializeToPosImpl(PaddedPODArray & pos, size_t start, size_t length, const NullMap * nullmap) const; - - template - void serializeToPosForColumnArrayImpl( - PaddedPODArray & pos, - size_t start, - size_t length, - const IColumn::Offsets & array_offsets, - const NullMap * nullmap) const; - public: std::string getName() const override { return "FixedString(" + std::to_string(n) + ")"; } const char * getFamilyName() const override { return "FixedString"; } @@ -124,22 +105,11 @@ class ColumnFixedString final : public COWPtrHelper const char * deserializeAndInsertFromArena(const char * pos, const TiDB::TiDBCollatorPtr &) override; - void countSerializeByteSize(PaddedPODArray & byte_size) const override - { - countSerializeByteSizeImpl(byte_size); - } + void countSerializeByteSize(PaddedPODArray & byte_size) const override; void countSerializeByteSizeForCmp( PaddedPODArray & byte_size, - const NullMap * /*nullmap*/, - const TiDB::TiDBCollatorPtr & collator) const override - { - // collator->sortKey() will change the string length, which may exceeds n. 
- RUNTIME_CHECK_MSG( - !collator, - "{} doesn't support countSerializeByteSizeForCmp when collator is not null", - getName()); - countSerializeByteSizeImpl(byte_size); - } + const NullMap * nullmap, + const TiDB::TiDBCollatorPtr & collator) const override; void countSerializeByteSizeForColumnArray( PaddedPODArray & byte_size, @@ -149,6 +119,11 @@ class ColumnFixedString final : public COWPtrHelper const IColumn::Offsets & array_offsets, const NullMap * nullmap, const TiDB::TiDBCollatorPtr & collator) const override; + template + void countSerializeByteSizeForColumnArrayImpl( + PaddedPODArray & byte_size, + const IColumn::Offsets & array_offsets, + const NullMap * nullmap) const; void serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const override; void serializeToPosForCmp( @@ -159,6 +134,8 @@ class ColumnFixedString final : public COWPtrHelper const NullMap * nullmap, const TiDB::TiDBCollatorPtr & collator, String *) const override; + template + void serializeToPosImpl(PaddedPODArray & pos, size_t start, size_t length, const NullMap * nullmap) const; void serializeToPosForColumnArray( PaddedPODArray & pos, @@ -175,6 +152,13 @@ class ColumnFixedString final : public COWPtrHelper const IColumn::Offsets & array_offsets, const TiDB::TiDBCollatorPtr & collator, String *) const override; + template + void serializeToPosForColumnArrayImpl( + PaddedPODArray & pos, + size_t start, + size_t length, + const IColumn::Offsets & array_offsets, + const NullMap * nullmap) const; void deserializeAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override; diff --git a/dbms/src/Columns/ColumnVector.cpp b/dbms/src/Columns/ColumnVector.cpp index fc39ae1c852..7bd13bbe8ec 100644 --- a/dbms/src/Columns/ColumnVector.cpp +++ b/dbms/src/Columns/ColumnVector.cpp @@ -262,7 +262,7 @@ void ColumnVector::serializeToPosForColumnArrayImpl( continue; } size_t len = array_offsets[start + i] - array_offsets[start + i - 1]; - auto start_idx = 
array_offsets[start + i - 1]; + size_t start_idx = array_offsets[start + i - 1]; if (len <= 4) { auto * p = pos[i]; @@ -403,7 +403,7 @@ void ColumnVector::deserializeAndInsertFromPosForColumnArray( for (size_t i = 0; i < size; ++i) { size_t len = array_offsets[start_point + i] - array_offsets[start_point + i - 1]; - auto start_idx = array_offsets[start_point + i - 1]; + size_t start_idx = array_offsets[start_point + i - 1]; if (len <= 4) { auto * p = pos[i]; diff --git a/dbms/src/Columns/IColumn.h b/dbms/src/Columns/IColumn.h index 1f77511bbcc..92c49ef86f4 100644 --- a/dbms/src/Columns/IColumn.h +++ b/dbms/src/Columns/IColumn.h @@ -352,6 +352,7 @@ class IColumn : public COWPtr bool /* use_nt_align_buffer */) = 0; + /// Flush the non-temporal align buffer if any. virtual void flushNTAlignBuffer() = 0; /// Update state of hash function with value of n-th element. From 5af75189217a3f44a0fad6d370696729b8e04cf7 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Wed, 26 Mar 2025 19:50:32 +0800 Subject: [PATCH 11/84] add deserializeAndAdvancePos interface Signed-off-by: gengliqi --- dbms/src/Columns/ColumnAggregateFunction.h | 16 ++ dbms/src/Columns/ColumnArray.cpp | 26 +++ dbms/src/Columns/ColumnArray.h | 11 ++ dbms/src/Columns/ColumnConst.h | 16 ++ dbms/src/Columns/ColumnDecimal.cpp | 21 +++ dbms/src/Columns/ColumnDecimal.h | 8 + dbms/src/Columns/ColumnFixedString.cpp | 17 ++ dbms/src/Columns/ColumnFixedString.h | 5 + dbms/src/Columns/ColumnFunction.h | 16 ++ dbms/src/Columns/ColumnNullable.cpp | 14 ++ dbms/src/Columns/ColumnNullable.h | 5 + dbms/src/Columns/ColumnString.cpp | 35 ++++ dbms/src/Columns/ColumnString.h | 33 ++-- dbms/src/Columns/ColumnTuple.h | 13 ++ dbms/src/Columns/ColumnVector.cpp | 21 +++ dbms/src/Columns/ColumnVector.h | 8 + dbms/src/Columns/IColumn.h | 23 ++- dbms/src/Columns/IColumnDummy.h | 16 ++ .../gtest_column_serialize_deserialize.cpp | 151 +++++++++++------- 19 files changed, 371 insertions(+), 84 deletions(-) diff --git 
a/dbms/src/Columns/ColumnAggregateFunction.h b/dbms/src/Columns/ColumnAggregateFunction.h index 6e8c6563ccb..a3a688db282 100644 --- a/dbms/src/Columns/ColumnAggregateFunction.h +++ b/dbms/src/Columns/ColumnAggregateFunction.h @@ -268,6 +268,22 @@ class ColumnAggregateFunction final : public COWPtrHelper & /* pos */) const override + { + throw Exception( + "Method deserializeAndAdvancePos is not supported for " + getName(), + ErrorCodes::NOT_IMPLEMENTED); + } + + void deserializeAndAdvancePosForColumnArray( + PaddedPODArray & /* pos */, + const IColumn::Offsets & /* array_offsets */) const override + { + throw Exception( + "Method deserializeAndAdvancePosForColumnArray is not supported for " + getName(), + ErrorCodes::NOT_IMPLEMENTED); + } + void updateHashWithValue(size_t n, SipHash & hash, const TiDB::TiDBCollatorPtr &, String &) const override; void updateHashWithValues(IColumn::HashValues & hash_values, const TiDB::TiDBCollatorPtr &, String &) diff --git a/dbms/src/Columns/ColumnArray.cpp b/dbms/src/Columns/ColumnArray.cpp index b9fe2f7f784..6b480297f4d 100644 --- a/dbms/src/Columns/ColumnArray.cpp +++ b/dbms/src/Columns/ColumnArray.cpp @@ -30,6 +30,8 @@ #include #include // memcpy +#include "Core/Defines.h" + namespace DB { @@ -399,6 +401,30 @@ void ColumnArray::flushNTAlignBuffer() getData().flushNTAlignBuffer(); } +void ColumnArray::deserializeAndAdvancePos(PaddedPODArray & pos) const +{ + static thread_local IColumn::Offsets offsets; + + size_t size = pos.size(); + offsets.resize(size); + for (size_t i = 0; i < size; ++i) + { + UInt32 len; + tiflash_compiler_builtin_memcpy(&len, pos[i], sizeof(UInt32)); + pos[i] += sizeof(UInt32); + offsets[i] = offsets[i - 1] + len; + } + + getData().deserializeAndAdvancePosForColumnArray(pos, offsets); + + // Free the memory of offsets if the size of pos is too large. 
+ if unlikely (offsets.size() > DEFAULT_BLOCK_SIZE) + { + IColumn::Offsets tmp_offsets; + offsets.swap(tmp_offsets); + } +} + void ColumnArray::updateHashWithValue( size_t n, SipHash & hash, diff --git a/dbms/src/Columns/ColumnArray.h b/dbms/src/Columns/ColumnArray.h index b5b2c6ee9b2..ffdb00d302c 100644 --- a/dbms/src/Columns/ColumnArray.h +++ b/dbms/src/Columns/ColumnArray.h @@ -170,6 +170,17 @@ class ColumnArray final : public COWPtrHelper void flushNTAlignBuffer() override; + void deserializeAndAdvancePos(PaddedPODArray & pos) const override; + + void deserializeAndAdvancePosForColumnArray( + PaddedPODArray & /* pos */, + const IColumn::Offsets & /* array_offsets */) const override + { + throw Exception( + "Method deserializeAndAdvancePosForColumnArray is not supported for " + getName(), + ErrorCodes::NOT_IMPLEMENTED); + } + void updateHashWithValue(size_t n, SipHash & hash, const TiDB::TiDBCollatorPtr &, String &) const override; void updateHashWithValues(IColumn::HashValues & hash_values, const TiDB::TiDBCollatorPtr &, String &) const override; diff --git a/dbms/src/Columns/ColumnConst.h b/dbms/src/Columns/ColumnConst.h index 0702bebe5ce..1baed2636f6 100644 --- a/dbms/src/Columns/ColumnConst.h +++ b/dbms/src/Columns/ColumnConst.h @@ -210,6 +210,22 @@ class ColumnConst final : public COWPtrHelper throw Exception("Method flushNTAlignBuffer is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } + void deserializeAndAdvancePos(PaddedPODArray & /* pos */) const override + { + throw Exception( + "Method deserializeAndAdvancePos is not supported for " + getName(), + ErrorCodes::NOT_IMPLEMENTED); + } + + void deserializeAndAdvancePosForColumnArray( + PaddedPODArray & /* pos */, + const IColumn::Offsets & /* array_offsets */) const override + { + throw Exception( + "Method deserializeAndAdvancePosForColumnArray is not supported for " + getName(), + ErrorCodes::NOT_IMPLEMENTED); + } + void updateHashWithValue( size_t, SipHash & hash, diff --git 
a/dbms/src/Columns/ColumnDecimal.cpp b/dbms/src/Columns/ColumnDecimal.cpp index 356c75937fa..872b874d19a 100644 --- a/dbms/src/Columns/ColumnDecimal.cpp +++ b/dbms/src/Columns/ColumnDecimal.cpp @@ -533,6 +533,9 @@ void ColumnDecimal::deserializeAndInsertFromPosForColumnArray( const IColumn::Offsets & array_offsets, bool use_nt_align_buffer [[maybe_unused]]) { + // Check if pos is empty is necessary. + // If pos is not empty, then array_offsets is not empty either due to pos.size() <= array_offsets.size(). + // Then reading array_offsets[-1] and array_offsets.back() is valid. if unlikely (pos.empty()) return; RUNTIME_CHECK_MSG( @@ -592,6 +595,24 @@ void ColumnDecimal::flushNTAlignBuffer() #endif } +template +void ColumnDecimal::deserializeAndAdvancePosForColumnArray( + PaddedPODArray & pos, + const IColumn::Offsets & array_offsets) const +{ + RUNTIME_CHECK_MSG( + pos.size() == array_offsets.size(), + "size of pos({}) != size of array_offsets({})", + pos.size(), + array_offsets.size()); + size_t size = pos.size(); + for (size_t i = 0; i < size; ++i) + { + size_t len = array_offsets[i] - array_offsets[i - 1]; + pos[i] += len * sizeof(T); + } +} + template UInt64 ColumnDecimal::get64(size_t n) const { diff --git a/dbms/src/Columns/ColumnDecimal.h b/dbms/src/Columns/ColumnDecimal.h index 9de8ce0aeb7..6a1394368ba 100644 --- a/dbms/src/Columns/ColumnDecimal.h +++ b/dbms/src/Columns/ColumnDecimal.h @@ -219,6 +219,14 @@ class ColumnDecimal final : public COWPtrHelper & pos) const override + { + IColumn::advancePosByOffset(pos, sizeof(T)); + } + + void deserializeAndAdvancePosForColumnArray(PaddedPODArray & pos, const IColumn::Offsets & array_offsets) + const override; + void updateHashWithValue(size_t n, SipHash & hash, const TiDB::TiDBCollatorPtr &, String &) const override; void updateHashWithValues(IColumn::HashValues & hash_values, const TiDB::TiDBCollatorPtr &, String &) const override; diff --git a/dbms/src/Columns/ColumnFixedString.cpp 
b/dbms/src/Columns/ColumnFixedString.cpp index 949ad243785..d3150f03614 100644 --- a/dbms/src/Columns/ColumnFixedString.cpp +++ b/dbms/src/Columns/ColumnFixedString.cpp @@ -403,6 +403,23 @@ void ColumnFixedString::deserializeAndInsertFromPosForColumnArray( } } +void ColumnFixedString::deserializeAndAdvancePosForColumnArray( + PaddedPODArray & pos, + const IColumn::Offsets & array_offsets) const +{ + RUNTIME_CHECK_MSG( + pos.size() == array_offsets.size(), + "size of pos({}) != size of array_offsets({})", + pos.size(), + array_offsets.size()); + size_t size = pos.size(); + for (size_t i = 0; i < size; ++i) + { + size_t len = array_offsets[i] - array_offsets[i - 1]; + pos[i] += n * len; + } +} + void ColumnFixedString::updateHashWithValue(size_t index, SipHash & hash, const TiDB::TiDBCollatorPtr &, String &) const { hash.update(reinterpret_cast(&chars[n * index]), n); diff --git a/dbms/src/Columns/ColumnFixedString.h b/dbms/src/Columns/ColumnFixedString.h index 736f3b79d2e..5b28a7097ef 100644 --- a/dbms/src/Columns/ColumnFixedString.h +++ b/dbms/src/Columns/ColumnFixedString.h @@ -169,6 +169,11 @@ class ColumnFixedString final : public COWPtrHelper void flushNTAlignBuffer() override {} + void deserializeAndAdvancePos(PaddedPODArray & pos) const override { IColumn::advancePosByOffset(pos, n); } + + void deserializeAndAdvancePosForColumnArray(PaddedPODArray & pos, const IColumn::Offsets & array_offsets) + const override; + void updateHashWithValue(size_t index, SipHash & hash, const TiDB::TiDBCollatorPtr &, String &) const override; void updateHashWithValues(IColumn::HashValues & hash_values, const TiDB::TiDBCollatorPtr &, String &) diff --git a/dbms/src/Columns/ColumnFunction.h b/dbms/src/Columns/ColumnFunction.h index 4a0b093ee13..9725a2d43e6 100644 --- a/dbms/src/Columns/ColumnFunction.h +++ b/dbms/src/Columns/ColumnFunction.h @@ -221,6 +221,22 @@ class ColumnFunction final : public COWPtrHelper throw Exception("Method flushNTAlignBuffer is not supported for " + 
getName(), ErrorCodes::NOT_IMPLEMENTED); } + void deserializeAndAdvancePos(PaddedPODArray & /* pos */) const override + { + throw Exception( + "Method deserializeAndAdvancePos is not supported for " + getName(), + ErrorCodes::NOT_IMPLEMENTED); + } + + void deserializeAndAdvancePosForColumnArray( + PaddedPODArray & /* pos */, + const IColumn::Offsets & /* array_offsets */) const override + { + throw Exception( + "Method deserializeAndAdvancePosForColumnArray is not supported for " + getName(), + ErrorCodes::NOT_IMPLEMENTED); + } + void updateHashWithValue(size_t, SipHash &, const TiDB::TiDBCollatorPtr &, String &) const override { throw Exception("updateHashWithValue is not implemented for " + getName(), ErrorCodes::NOT_IMPLEMENTED); diff --git a/dbms/src/Columns/ColumnNullable.cpp b/dbms/src/Columns/ColumnNullable.cpp index 9a2b640c5da..2a5403a1278 100644 --- a/dbms/src/Columns/ColumnNullable.cpp +++ b/dbms/src/Columns/ColumnNullable.cpp @@ -392,6 +392,20 @@ void ColumnNullable::flushNTAlignBuffer() getNestedColumn().flushNTAlignBuffer(); } +void ColumnNullable::deserializeAndAdvancePos(PaddedPODArray & pos) const +{ + getNullMapColumn().deserializeAndAdvancePos(pos); + getNestedColumn().deserializeAndAdvancePos(pos); +} + +void ColumnNullable::deserializeAndAdvancePosForColumnArray( + PaddedPODArray & pos, + const IColumn::Offsets & array_offsets) const +{ + getNullMapColumn().deserializeAndAdvancePosForColumnArray(pos, array_offsets); + getNestedColumn().deserializeAndAdvancePosForColumnArray(pos, array_offsets); +} + void ColumnNullable::insertRangeFrom(const IColumn & src, size_t start, size_t length) { const auto & nullable_col = static_cast(src); diff --git a/dbms/src/Columns/ColumnNullable.h b/dbms/src/Columns/ColumnNullable.h index 15f8a515ccd..a8389ed4efc 100644 --- a/dbms/src/Columns/ColumnNullable.h +++ b/dbms/src/Columns/ColumnNullable.h @@ -130,6 +130,11 @@ class ColumnNullable final : public COWPtrHelper void flushNTAlignBuffer() override; + void 
deserializeAndAdvancePos(PaddedPODArray & pos) const override; + + void deserializeAndAdvancePosForColumnArray(PaddedPODArray & pos, const IColumn::Offsets & array_offsets) + const override; + void insertRangeFrom(const IColumn & src, size_t start, size_t length) override; void insert(const Field & x) override; diff --git a/dbms/src/Columns/ColumnString.cpp b/dbms/src/Columns/ColumnString.cpp index a7ba9d3fbcc..65a9e888aad 100644 --- a/dbms/src/Columns/ColumnString.cpp +++ b/dbms/src/Columns/ColumnString.cpp @@ -1188,6 +1188,41 @@ void ColumnString::flushNTAlignBuffer() #endif } +void ColumnString::deserializeAndAdvancePos(PaddedPODArray & pos) const +{ + size_t size = pos.size(); + for (size_t i = 0; i < size; ++i) + { + UInt32 str_size; + tiflash_compiler_builtin_memcpy(&str_size, pos[i], sizeof(UInt32)); + pos[i] += sizeof(UInt32) + str_size; + } +} + +void ColumnString::deserializeAndAdvancePosForColumnArray( + PaddedPODArray & pos, + const IColumn::Offsets & array_offsets) const +{ + RUNTIME_CHECK_MSG( + pos.size() == array_offsets.size(), + "size of pos({}) != size of array_offsets({})", + pos.size(), + array_offsets.size()); + size_t size = pos.size(); + for (size_t i = 0; i < size; ++i) + { + size_t char_size = 0; + for (size_t j = array_offsets[i - 1]; j < array_offsets[i]; ++j) + { + UInt32 str_size; + tiflash_compiler_builtin_memcpy(&str_size, pos[i], sizeof(UInt32)); + pos[i] += sizeof(UInt32); + char_size += str_size; + } + pos[i] += char_size; + } +} + void updateWeakHash32BinPadding(const std::string_view & view, size_t idx, ColumnString::WeakHash32Info & info) { auto sort_key = BinCollatorSortKey(view.data(), view.size()); diff --git a/dbms/src/Columns/ColumnString.h b/dbms/src/Columns/ColumnString.h index f326dce8672..5e7b5b00a10 100644 --- a/dbms/src/Columns/ColumnString.h +++ b/dbms/src/Columns/ColumnString.h @@ -74,35 +74,21 @@ class ColumnString final : public COWPtrHelper void ALWAYS_INLINE insertFromImpl(const ColumnString & src, size_t n) { 
- if likely (n != 0) + const size_t size_to_append = src.sizeAt(n); + if (size_to_append == 1) { - const size_t size_to_append = src.offsets[n] - src.offsets[n - 1]; - - if (size_to_append == 1) - { - /// shortcut for empty string - chars.push_back(0); - offsets.push_back(chars.size()); - } - else - { - const size_t old_size = chars.size(); - const size_t offset = src.offsets[n - 1]; - const size_t new_size = old_size + size_to_append; - - chars.resize(new_size); - memcpySmallAllowReadWriteOverflow15(&chars[old_size], &src.chars[offset], size_to_append); - offsets.push_back(new_size); - } + /// shortcut for empty string + chars.push_back(0); + offsets.push_back(chars.size()); } else { const size_t old_size = chars.size(); - const size_t size_to_append = src.offsets[0]; + const size_t offset = src.offsets[n - 1]; const size_t new_size = old_size + size_to_append; chars.resize(new_size); - memcpySmallAllowReadWriteOverflow15(&chars[old_size], &src.chars[0], size_to_append); + memcpySmallAllowReadWriteOverflow15(&chars[old_size], &src.chars[offset], size_to_append); offsets.push_back(new_size); } } @@ -330,6 +316,11 @@ class ColumnString final : public COWPtrHelper void flushNTAlignBuffer() override; + void deserializeAndAdvancePos(PaddedPODArray & pos) const override; + + void deserializeAndAdvancePosForColumnArray(PaddedPODArray & pos, const IColumn::Offsets & array_offsets) + const override; + void updateHashWithValue( size_t n, SipHash & hash, diff --git a/dbms/src/Columns/ColumnTuple.h b/dbms/src/Columns/ColumnTuple.h index cf993910a14..315d96b23cf 100644 --- a/dbms/src/Columns/ColumnTuple.h +++ b/dbms/src/Columns/ColumnTuple.h @@ -202,6 +202,19 @@ class ColumnTuple final : public COWPtrHelper column->assumeMutableRef().flushNTAlignBuffer(); } + void deserializeAndAdvancePos(PaddedPODArray & pos) const override + { + for (const auto & column : columns) + column->deserializeAndAdvancePos(pos); + } + + void deserializeAndAdvancePosForColumnArray(PaddedPODArray & 
pos, const IColumn::Offsets & array_offsets) + const override + { + for (const auto & column : columns) + column->deserializeAndAdvancePosForColumnArray(pos, array_offsets); + } + void updateHashWithValue(size_t n, SipHash & hash, const TiDB::TiDBCollatorPtr &, String &) const override; void updateHashWithValues(IColumn::HashValues & hash_values, const TiDB::TiDBCollatorPtr &, String &) const override; diff --git a/dbms/src/Columns/ColumnVector.cpp b/dbms/src/Columns/ColumnVector.cpp index 7bd13bbe8ec..0179ae520c3 100644 --- a/dbms/src/Columns/ColumnVector.cpp +++ b/dbms/src/Columns/ColumnVector.cpp @@ -382,6 +382,9 @@ void ColumnVector::deserializeAndInsertFromPosForColumnArray( const IColumn::Offsets & array_offsets, bool use_nt_align_buffer [[maybe_unused]]) { + // Check if pos is empty is necessary. + // If pos is not empty, then array_offsets is not empty either due to pos.size() <= array_offsets.size(). + // Then reading array_offsets[-1] and array_offsets.back() is valid. if unlikely (pos.empty()) return; RUNTIME_CHECK_MSG( @@ -441,6 +444,24 @@ void ColumnVector::flushNTAlignBuffer() #endif } +template +void ColumnVector::deserializeAndAdvancePosForColumnArray( + PaddedPODArray & pos, + const IColumn::Offsets & array_offsets) const +{ + RUNTIME_CHECK_MSG( + pos.size() == array_offsets.size(), + "size of pos({}) != size of array_offsets({})", + pos.size(), + array_offsets.size()); + size_t size = pos.size(); + for (size_t i = 0; i < size; ++i) + { + size_t len = array_offsets[i] - array_offsets[i - 1]; + pos[i] += len * sizeof(T); + } +} + template void ColumnVector::updateHashWithValue(size_t n, SipHash & hash, const TiDB::TiDBCollatorPtr &, String &) const { diff --git a/dbms/src/Columns/ColumnVector.h b/dbms/src/Columns/ColumnVector.h index 17f6de6747d..1c418f7b75f 100644 --- a/dbms/src/Columns/ColumnVector.h +++ b/dbms/src/Columns/ColumnVector.h @@ -388,6 +388,14 @@ class ColumnVector final : public COWPtrHelper & pos) const override + { + 
IColumn::advancePosByOffset(pos, sizeof(T)); + } + + void deserializeAndAdvancePosForColumnArray(PaddedPODArray & pos, const IColumn::Offsets & array_offsets) + const override; + void updateHashWithValue(size_t n, SipHash & hash, const TiDB::TiDBCollatorPtr &, String &) const override; void updateHashWithValues(IColumn::HashValues & hash_values, const TiDB::TiDBCollatorPtr &, String &) const override; diff --git a/dbms/src/Columns/IColumn.h b/dbms/src/Columns/IColumn.h index 92c49ef86f4..a302bc28606 100644 --- a/dbms/src/Columns/IColumn.h +++ b/dbms/src/Columns/IColumn.h @@ -322,7 +322,7 @@ class IColumn : public COWPtr String * /* sort_key_container */) const = 0; - /// Deserialize and insert data from pos and forward each pos[i] to the end of serialized data. + /// Deserialize and insert data from pos and advance each pointer in pos to the end of serialized data. /// Note: /// 1. The pos pointer must not be nullptr. /// 2. If use_nt_align_buffer is true and AVX2 is enabled, non-temporal store may be used when data memory is aligned to FULL_VECTOR_SIZE_AVX2(64 bytes). @@ -342,7 +342,7 @@ class IColumn : public COWPtr /// column_ptr->flushNTAlignBuffer(); virtual void deserializeAndInsertFromPos(PaddedPODArray & /* pos */, bool /* use_nt_align_buffer */) = 0; - /// Deserialize and insert data from pos and forward each pos[i] to the end of serialized data. + /// Deserialize and insert data from pos and advance each pointer in pos to the end of serialized data. /// Only called by ColumnArray. /// array_offsets is the offsets of ColumnArray. /// The last pos.size() elements of array_offsets can be used to get the length of elements from each pos. @@ -352,9 +352,26 @@ class IColumn : public COWPtr bool /* use_nt_align_buffer */) = 0; - /// Flush the non-temporal align buffer if any. + /// Flush any remaining data in the non-temporal align buffer into the column data. 
+ /// This function must be called after all deserializeAndInsertFromPos calls when use_nt_align_buffer is enabled. virtual void flushNTAlignBuffer() = 0; + /// Deserialize from pos and advance each pointer in pos to the end of serialized data. + virtual void deserializeAndAdvancePos(PaddedPODArray & /* pos */) const = 0; + /// Advance each pointer in 'pos' by a fixed offset. + static void advancePosByOffset(PaddedPODArray & pos, size_t offset) + { + for (auto & p : pos) + p += offset; + } + + /// Deserialize from pos and advance each pointer in pos to the end of serialized data. + /// Only called by ColumnArray. + virtual void deserializeAndAdvancePosForColumnArray( + PaddedPODArray & /* pos */, + const Offsets & /* array_offsets */) const + = 0; + /// Update state of hash function with value of n-th element. /// On subsequent calls of this method for sequence of column values of arbitary types, /// passed bytes to hash must identify sequence of values unambiguously. diff --git a/dbms/src/Columns/IColumnDummy.h b/dbms/src/Columns/IColumnDummy.h index ac45ff60580..7900184c03e 100644 --- a/dbms/src/Columns/IColumnDummy.h +++ b/dbms/src/Columns/IColumnDummy.h @@ -189,6 +189,22 @@ class IColumnDummy : public IColumn throw Exception("Method flushNTAlignBuffer is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } + void deserializeAndAdvancePos(PaddedPODArray & /* pos */) const override + { + throw Exception( + "Method deserializeAndAdvancePos is not supported for " + getName(), + ErrorCodes::NOT_IMPLEMENTED); + } + + void deserializeAndAdvancePosForColumnArray( + PaddedPODArray & /* pos */, + const IColumn::Offsets & /* array_offsets */) const override + { + throw Exception( + "Method deserializeAndAdvancePosForColumnArray is not supported for " + getName(), + ErrorCodes::NOT_IMPLEMENTED); + } + void updateHashWithValue(size_t /*n*/, SipHash & /*hash*/, const TiDB::TiDBCollatorPtr &, String &) const override {} diff --git 
a/dbms/src/Columns/tests/gtest_column_serialize_deserialize.cpp b/dbms/src/Columns/tests/gtest_column_serialize_deserialize.cpp index fe5f652328e..4e58b978270 100644 --- a/dbms/src/Columns/tests/gtest_column_serialize_deserialize.cpp +++ b/dbms/src/Columns/tests/gtest_column_serialize_deserialize.cpp @@ -211,24 +211,12 @@ class TestColumnSerializeDeserialize : public ::testing::Test } } - static void testSerializeAndDeserialize( - const ColumnPtr & column_ptr, - bool compare_semantics = false, - const TiDB::TiDBCollatorPtr & collator = nullptr, - String * sort_key_container = nullptr) + static void testSerializeAndDeserialize(const ColumnPtr & column_ptr) { - if (compare_semantics) - { - doTestSerializeAndDeserializeForCmp(column_ptr, true, collator, sort_key_container); - doTestSerializeAndDeserializeForCmp(column_ptr, false, collator, sort_key_container); - } - else - { - doTestSerializeAndDeserialize(column_ptr, false); - doTestSerializeAndDeserialize2(column_ptr, false); - doTestSerializeAndDeserialize(column_ptr, true); - doTestSerializeAndDeserialize2(column_ptr, true); - } + doTestSerializeAndDeserialize(column_ptr, false); + doTestSerializeAndDeserialize2(column_ptr, false); + doTestSerializeAndDeserialize(column_ptr, true); + doTestSerializeAndDeserialize2(column_ptr, true); } static void doTestSerializeAndDeserialize(const ColumnPtr & column_ptr, bool use_nt_align_buffer) @@ -255,7 +243,11 @@ class TestColumnSerializeDeserialize : public ::testing::Test auto new_col_ptr = column_ptr->cloneEmpty(); if (use_nt_align_buffer) new_col_ptr->reserveAlign(byte_size.size(), FULL_VECTOR_SIZE_AVX2); + + pos.assign(ori_pos); new_col_ptr->deserializeAndInsertFromPos(ori_pos, use_nt_align_buffer); + new_col_ptr->deserializeAndAdvancePos(pos); + ASSERT_EQ(pos, ori_pos); current_size = 0; pos.clear(); @@ -272,7 +264,10 @@ class TestColumnSerializeDeserialize : public ::testing::Test pos.resize(pos.size() - 1); ori_pos.resize(ori_pos.size() - 1); + pos.assign(ori_pos); 
new_col_ptr->deserializeAndInsertFromPos(ori_pos, use_nt_align_buffer); + new_col_ptr->deserializeAndAdvancePos(pos); + ASSERT_EQ(pos, ori_pos); current_size = 0; pos.clear(); @@ -286,7 +281,11 @@ class TestColumnSerializeDeserialize : public ::testing::Test ori_pos.push_back(ptr); column_ptr->serializeToPos(pos, 0, byte_size.size(), true); + pos.assign(ori_pos); new_col_ptr->deserializeAndInsertFromPos(ori_pos, use_nt_align_buffer); + new_col_ptr->deserializeAndAdvancePos(pos); + ASSERT_EQ(pos, ori_pos); + if (use_nt_align_buffer) new_col_ptr->flushNTAlignBuffer(); @@ -327,7 +326,11 @@ class TestColumnSerializeDeserialize : public ::testing::Test auto new_col_ptr = column_ptr->cloneEmpty(); if (use_nt_align_buffer) new_col_ptr->reserveAlign(byte_size.size(), FULL_VECTOR_SIZE_AVX2); + + pos.assign(ori_pos); new_col_ptr->deserializeAndInsertFromPos(ori_pos, use_nt_align_buffer); + new_col_ptr->deserializeAndAdvancePos(pos); + ASSERT_EQ(pos, ori_pos); current_size = 0; pos.clear(); @@ -340,7 +343,11 @@ class TestColumnSerializeDeserialize : public ::testing::Test for (auto * ptr : pos) ori_pos.push_back(ptr); column_ptr->serializeToPos(pos, byte_size.size() / 2 - 1, byte_size.size() - byte_size.size() / 2 + 1, false); + + pos.assign(ori_pos); new_col_ptr->deserializeAndInsertFromPos(ori_pos, use_nt_align_buffer); + new_col_ptr->deserializeAndAdvancePos(pos); + ASSERT_EQ(pos, ori_pos); current_size = 0; pos.clear(); @@ -354,7 +361,11 @@ class TestColumnSerializeDeserialize : public ::testing::Test ori_pos.push_back(ptr); column_ptr->serializeToPos(pos, 0, byte_size.size(), true); + pos.assign(ori_pos); new_col_ptr->deserializeAndInsertFromPos(ori_pos, use_nt_align_buffer); + new_col_ptr->deserializeAndAdvancePos(pos); + ASSERT_EQ(pos, ori_pos); + if (use_nt_align_buffer) new_col_ptr->flushNTAlignBuffer(); @@ -365,6 +376,15 @@ class TestColumnSerializeDeserialize : public ::testing::Test ASSERT_COLUMN_EQ(std::move(result_col_ptr), std::move(new_col_ptr)); } + static 
void testSerializeAndDeserializeForCmp( + const ColumnPtr & column_ptr, + const TiDB::TiDBCollatorPtr & collator = nullptr, + String * sort_key_container = nullptr) + { + doTestSerializeAndDeserializeForCmp(column_ptr, true, collator, sort_key_container); + doTestSerializeAndDeserializeForCmp(column_ptr, false, collator, sort_key_container); + } + static void doTestSerializeAndDeserializeForCmp( const ColumnPtr & column_ptr, bool use_nt_align_buffer, @@ -393,7 +413,11 @@ class TestColumnSerializeDeserialize : public ::testing::Test auto new_col_ptr = column_ptr->cloneEmpty(); if (use_nt_align_buffer) new_col_ptr->reserveAlign(byte_size.size(), FULL_VECTOR_SIZE_AVX2); + + pos.assign(ori_pos); new_col_ptr->deserializeAndInsertFromPos(ori_pos, use_nt_align_buffer); + new_col_ptr->deserializeAndAdvancePos(pos); + ASSERT_EQ(pos, ori_pos); current_size = 0; pos.clear(); @@ -414,7 +438,10 @@ class TestColumnSerializeDeserialize : public ::testing::Test collator, sort_key_container); + pos.assign(ori_pos); new_col_ptr->deserializeAndInsertFromPos(ori_pos, use_nt_align_buffer); + new_col_ptr->deserializeAndAdvancePos(pos); + ASSERT_EQ(pos, ori_pos); current_size = 0; pos.clear(); @@ -428,7 +455,11 @@ class TestColumnSerializeDeserialize : public ::testing::Test ori_pos.push_back(ptr); column_ptr->serializeToPosForCmp(pos, 0, byte_size.size(), false, nullptr, collator, sort_key_container); + + pos.assign(ori_pos); new_col_ptr->deserializeAndInsertFromPos(ori_pos, use_nt_align_buffer); + new_col_ptr->deserializeAndAdvancePos(pos); + ASSERT_EQ(pos, ori_pos); if (use_nt_align_buffer) new_col_ptr->flushNTAlignBuffer(); @@ -450,7 +481,7 @@ try auto col_vector_1 = createColumn({1}).column; testCountSerializeByteSize(col_vector_1, {4}); testSerializeAndDeserialize(col_vector_1); - testSerializeAndDeserialize(col_vector_1, true, nullptr, nullptr); + testSerializeAndDeserializeForCmp(col_vector_1, nullptr, nullptr); auto col_vector = createColumn({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 
12, 13, 14, 15, 16, 17, 18}).column; testCountSerializeByteSize(col_vector, {8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8}); @@ -458,7 +489,7 @@ try testCountSerializeByteSizeForColumnArray(col_vector, col_offsets, {8, 16, 24, 32, 48, 16}); testSerializeAndDeserialize(col_vector); - testSerializeAndDeserialize(col_vector, true, nullptr, nullptr); + testSerializeAndDeserializeForCmp(col_vector, nullptr, nullptr); } CATCH @@ -468,7 +499,7 @@ try auto col_decimal_1 = createColumn(std::make_tuple(10, 3), {"1234567.333"}).column; testCountSerializeByteSize(col_decimal_1, {16}); testSerializeAndDeserialize(col_decimal_1); - testSerializeAndDeserialize(col_decimal_1, true, nullptr, nullptr); + testSerializeAndDeserializeForCmp(col_decimal_1, nullptr, nullptr); auto col_decimal = createColumn( std::make_tuple(8, 2), @@ -482,7 +513,7 @@ try testCountSerializeByteSizeForColumnArray(col_decimal, col_offsets, {4, 8, 12, 6 * 4, 21 * 4}); testSerializeAndDeserialize(col_decimal); - testSerializeAndDeserialize(col_decimal, true, nullptr, nullptr); + testSerializeAndDeserializeForCmp(col_decimal, nullptr, nullptr); auto col_decimal_256 = createColumn( std::make_tuple(61, 4), @@ -525,7 +556,7 @@ try testCountSerializeByteSizeForColumnArray(col_decimal_256, col_offsets, {48, 2 * 48, 3 * 48, 6 * 48, 21 * 48}); testSerializeAndDeserialize(col_decimal_256); - testSerializeAndDeserialize(col_decimal_256, true, nullptr, nullptr); + testSerializeAndDeserializeForCmp(col_decimal_256, nullptr, nullptr); // Also test row-base interface for ColumnDecimal. 
Arena arena; @@ -553,9 +584,9 @@ try auto col_string_1 = createColumn({"sdafyuwer123"}).column; testCountSerializeByteSize(col_string_1, {4 + 13}); testSerializeAndDeserialize(col_string_1); - testSerializeAndDeserialize(col_string_1, true, collator_utf8_bin, &sort_key_container); - testSerializeAndDeserialize(col_string_1, true, collator_utf8_general_ci, &sort_key_container); - testSerializeAndDeserialize(col_string_1, true, collator_utf8_unicode_ci, &sort_key_container); + testSerializeAndDeserializeForCmp(col_string_1, collator_utf8_bin, &sort_key_container); + testSerializeAndDeserializeForCmp(col_string_1, collator_utf8_general_ci, &sort_key_container); + testSerializeAndDeserializeForCmp(col_string_1, collator_utf8_unicode_ci, &sort_key_container); auto col_string = createColumn({"123", "1234567890", @@ -597,9 +628,9 @@ try testSerializeAndDeserialize(col_string); - testSerializeAndDeserialize(col_string, true, collator_utf8_bin, &sort_key_container); - testSerializeAndDeserialize(col_string, true, collator_utf8_general_ci, &sort_key_container); - testSerializeAndDeserialize(col_string, true, collator_utf8_unicode_ci, &sort_key_container); + testSerializeAndDeserializeForCmp(col_string, collator_utf8_bin, &sort_key_container); + testSerializeAndDeserializeForCmp(col_string, collator_utf8_general_ci, &sort_key_container); + testSerializeAndDeserializeForCmp(col_string, collator_utf8_unicode_ci, &sort_key_container); } CATCH @@ -620,7 +651,7 @@ try testSerializeAndDeserialize(col_fixed_string); // ColumnFixedString doesn't support serialize/deserialize with collator for now. 
- testSerializeAndDeserialize(col_fixed_string, true, nullptr, nullptr); + testSerializeAndDeserializeForCmp(col_fixed_string, nullptr, nullptr); } CATCH @@ -643,7 +674,7 @@ try testCountSerializeByteSize(col_nullable_decimal_0, {49, 49, 49, 49, 49}); testSerializeAndDeserialize(col_nullable_decimal_0); testCountSerializeByteSize(col_nullable_decimal_0, {49, 49, 49, 49, 49}, true, nullptr); - testSerializeAndDeserialize(col_nullable_decimal_0, true, nullptr, nullptr); + testSerializeAndDeserializeForCmp(col_nullable_decimal_0, nullptr, nullptr); } // ColumnNullable(ColumnDecimal32) @@ -661,7 +692,7 @@ try testCountSerializeByteSize(col_nullable_decimal_1, {1 + 4, 1 + 4, 1 + 4, 1 + 4}); testSerializeAndDeserialize(col_nullable_decimal_1); testCountSerializeByteSize(col_nullable_decimal_1, {1 + 4, 1 + 4, 1 + 4, 1 + 4}, true, nullptr); - testSerializeAndDeserialize(col_nullable_decimal_1, true, nullptr, nullptr); + testSerializeAndDeserializeForCmp(col_nullable_decimal_1, nullptr, nullptr); } // ColumnNullable(ColumnVector) @@ -670,7 +701,7 @@ try testCountSerializeByteSize(col_nullable_vec, {9, 9, 9, 9, 9, 9}); testSerializeAndDeserialize(col_nullable_vec); testCountSerializeByteSize(col_nullable_vec, {9, 9, 9, 9, 9, 9}, true, nullptr); - testSerializeAndDeserialize(col_nullable_vec, true, nullptr, nullptr); + testSerializeAndDeserializeForCmp(col_nullable_vec, nullptr, nullptr); } String sort_key_container; @@ -687,9 +718,9 @@ try testSerializeAndDeserialize(col_nullable_string); // 5: 1(null) + 4(sizeof(UInt32)) testCountSerializeByteSize(col_nullable_string, {5 + 4, 5 + 1, 5 + 3, 5 + 1, 5 + 5, 5 + 1}, true, nullptr); - testSerializeAndDeserialize(col_nullable_string, true, collator_utf8_bin, &sort_key_container); - testSerializeAndDeserialize(col_nullable_string, true, collator_utf8_general_ci, &sort_key_container); - testSerializeAndDeserialize(col_nullable_string, true, collator_utf8_unicode_ci, &sort_key_container); + 
testSerializeAndDeserializeForCmp(col_nullable_string, collator_utf8_bin, &sort_key_container); + testSerializeAndDeserializeForCmp(col_nullable_string, collator_utf8_general_ci, &sort_key_container); + testSerializeAndDeserializeForCmp(col_nullable_string, collator_utf8_unicode_ci, &sort_key_container); } // ColumnNullable(ColumnFixedString) @@ -710,7 +741,7 @@ try {1 + 2, 1 + 2, 1 + 2, 1 + 2, 1 + 2, 1 + 2}, true, nullptr); - testSerializeAndDeserialize(col_nullable_fixed_string, true, nullptr, nullptr); + testSerializeAndDeserializeForCmp(col_nullable_fixed_string, nullptr, nullptr); } auto col_offsets = createColumn({1, 3, 6}).column; @@ -732,7 +763,7 @@ try testCountSerializeByteSize(col_nullable_array_dec, {1 + 4 + 48, 1 + 4 + 48 * 2, 1 + 4 + 48 * 3}); testSerializeAndDeserialize(col_nullable_array_dec); testCountSerializeByteSize(col_nullable_array_dec, {1 + 4, 1 + 4 + 48 * 2, 1 + 4}, true, nullptr); - testSerializeAndDeserialize(col_nullable_array_dec, true, nullptr, nullptr); + testSerializeAndDeserializeForCmp(col_nullable_array_dec, nullptr, nullptr); } // ColumnNullable(ColumnArray(ColumnDecimal128) @@ -753,7 +784,7 @@ try testCountSerializeByteSize(col_nullable_array_dec_1, {1 + 4 + 16, 1 + 4 + 2 * 16, 1 + 4 + 3 * 16}); testSerializeAndDeserialize(col_nullable_array_dec_1); testCountSerializeByteSize(col_nullable_array_dec_1, {1 + 4, 1 + 4 + 2 * 16, 1 + 4}, true, nullptr); - testSerializeAndDeserialize(col_nullable_array_dec_1, true, nullptr, nullptr); + testSerializeAndDeserializeForCmp(col_nullable_array_dec_1, nullptr, nullptr); } // ColumnNullable(ColumnArray(ColumnVector)) @@ -764,7 +795,7 @@ try testCountSerializeByteSize(col_nullable_array_vec, {1 + 4 + 4, 1 + 4 + 8, 1 + 4 + 12}); testSerializeAndDeserialize(col_nullable_array_vec); testCountSerializeByteSize(col_nullable_array_vec, {1 + 4, 1 + 4 + 8, 1 + 4}, true, nullptr); - testSerializeAndDeserialize(col_nullable_array_vec, true, nullptr, nullptr); + 
testSerializeAndDeserializeForCmp(col_nullable_array_vec, nullptr, nullptr); } // ColumnNullable(ColumnArray(ColumnString)) @@ -776,7 +807,7 @@ try testCountSerializeByteSize(col_nullable_array_string, {1 + 4 + 4 + 4, 1 + 4 + 4 * 2 + 5, 1 + 4 + 4 * 3 + 11}); testSerializeAndDeserialize(col_nullable_array_string); testCountSerializeByteSize(col_nullable_array_string, {1 + 4, 1 + 4 + 4 * 2 + 5, 1 + 4}, true, nullptr); - testSerializeAndDeserialize(col_nullable_array_string, true, nullptr, nullptr); + testSerializeAndDeserializeForCmp(col_nullable_array_string, nullptr, nullptr); } // ColumnNullable(ColumnArray(ColumnString)) with utf8 char. { @@ -786,9 +817,9 @@ try auto col_array_string = ColumnArray::create(col_string, col_offsets); auto col_nullable_array_string = ColumnNullable::create(col_array_string, createColumn({1, 0, 1}).column); - testSerializeAndDeserialize(col_nullable_array_string, true, collator_utf8_bin, &sort_key_container); - testSerializeAndDeserialize(col_nullable_array_string, true, collator_utf8_general_ci, &sort_key_container); - testSerializeAndDeserialize(col_nullable_array_string, true, collator_utf8_unicode_ci, &sort_key_container); + testSerializeAndDeserializeForCmp(col_nullable_array_string, collator_utf8_bin, &sort_key_container); + testSerializeAndDeserializeForCmp(col_nullable_array_string, collator_utf8_general_ci, &sort_key_container); + testSerializeAndDeserializeForCmp(col_nullable_array_string, collator_utf8_unicode_ci, &sort_key_container); } // ColumnNullable(ColumnArray(ColumnFixedString)) @@ -807,7 +838,7 @@ try testCountSerializeByteSize(col_nullable_array_fixed_string, {1 + 4 + 2, 1 + 4 + 4, 1 + 4 + 6}); testSerializeAndDeserialize(col_nullable_array_fixed_string); testCountSerializeByteSize(col_nullable_array_fixed_string, {1 + 4, 1 + 4 + 4, 1 + 4}, true, nullptr); - testSerializeAndDeserialize(col_nullable_array_fixed_string, true, nullptr, nullptr); + testSerializeAndDeserializeForCmp(col_nullable_array_fixed_string, 
nullptr, nullptr); } // ColumnNullable(ColumnNullable(xxx)) not support. @@ -821,9 +852,9 @@ try // 1 + 4 + 2 + 8 + 4, // 1 + 4 + 3 + 12 + 7}, true, nullptr); // testSerializeAndDeserialize(col_nullable_array_string); - // testSerializeAndDeserialize(col_nullable_array_string, true, collator_utf8_bin, &sort_key_container); - // testSerializeAndDeserialize(col_nullable_array_string, true, collator_utf8_general_ci, &sort_key_container); - // testSerializeAndDeserialize(col_nullable_array_string, true, collator_utf8_unicode_ci, &sort_key_container); + // testSerializeAndDeserializeForCmp(col_nullable_array_string, collator_utf8_bin, &sort_key_container); + // testSerializeAndDeserializeForCmp(col_nullable_array_string, collator_utf8_general_ci, &sort_key_container); + // testSerializeAndDeserializeForCmp(col_nullable_array_string, collator_utf8_unicode_ci, &sort_key_container); } CATCH @@ -836,7 +867,7 @@ try auto col_array_vec = ColumnArray::create(col_vector, col_offsets); testCountSerializeByteSize(col_array_vec, {4 + 4, 4 + 8, 4 + 12}); testSerializeAndDeserialize(col_array_vec); - testSerializeAndDeserialize(col_array_vec, true, nullptr, nullptr); + testSerializeAndDeserializeForCmp(col_array_vec, nullptr, nullptr); // ColumnArray(ColumnString) String sort_key_container; @@ -849,9 +880,9 @@ try auto col_array_string = ColumnArray::create(col_string, col_offsets); testCountSerializeByteSize(col_array_string, {4 + 4 + 4, 4 + 8 + 5, 4 + 12 + 11}); testSerializeAndDeserialize(col_array_string); - testSerializeAndDeserialize(col_array_string, true, collator_utf8_bin, &sort_key_container); - testSerializeAndDeserialize(col_array_string, true, collator_utf8_general_ci, &sort_key_container); - testSerializeAndDeserialize(col_array_string, true, collator_utf8_unicode_ci, &sort_key_container); + testSerializeAndDeserializeForCmp(col_array_string, collator_utf8_bin, &sort_key_container); + testSerializeAndDeserializeForCmp(col_array_string, collator_utf8_general_ci, 
&sort_key_container); + testSerializeAndDeserializeForCmp(col_array_string, collator_utf8_unicode_ci, &sort_key_container); // ColumnArray(ColumnNullable(ColumnString)) auto col_nullable_string @@ -907,7 +938,7 @@ try col_array_decimal_256, {4 + 3 * 48, 4 + 5 * 48, 4 + 7 * 48, 4 + 5 * 48, 4 + 10 * 48, 4 + 48, 4 + 48, 4 + 48}); testSerializeAndDeserialize(col_array_decimal_256); - testSerializeAndDeserialize(col_array_decimal_256, true, nullptr, nullptr); + testSerializeAndDeserializeForCmp(col_array_decimal_256, nullptr, nullptr); // ColumnArray(ColumnFixedString) auto col_fixed_string_mut = ColumnFixedString::create(2); @@ -921,7 +952,7 @@ try auto col_array_fixed_string = ColumnArray::create(col_fixed_string, col_offsets); testCountSerializeByteSize(col_array_fixed_string, {4 + 2, 4 + 4, 4 + 6}); testSerializeAndDeserialize(col_array_fixed_string); - testSerializeAndDeserialize(col_array_fixed_string, true, nullptr, nullptr); + testSerializeAndDeserializeForCmp(col_array_fixed_string, nullptr, nullptr); } CATCH @@ -943,9 +974,9 @@ try = TiDB::ITiDBCollator::getCollator(TiDB::ITiDBCollator::UTF8MB4_GENERAL_CI); TiDB::TiDBCollatorPtr collator_utf8_unicode_ci = TiDB::ITiDBCollator::getCollator(TiDB::ITiDBCollator::UTF8MB4_UNICODE_CI); - testSerializeAndDeserialize(col_tuple, true, collator_utf8_bin, &sort_key_container); - testSerializeAndDeserialize(col_tuple, true, collator_utf8_general_ci, &sort_key_container); - testSerializeAndDeserialize(col_tuple, true, collator_utf8_unicode_ci, &sort_key_container); + testSerializeAndDeserializeForCmp(col_tuple, collator_utf8_bin, &sort_key_container); + testSerializeAndDeserializeForCmp(col_tuple, collator_utf8_general_ci, &sort_key_container); + testSerializeAndDeserializeForCmp(col_tuple, collator_utf8_unicode_ci, &sort_key_container); } CATCH @@ -1015,9 +1046,9 @@ try // ColumnString String sort_key_container; - testSerializeAndDeserialize(col_string, true, collator_utf8_bin, &sort_key_container); - 
testSerializeAndDeserialize(col_string, true, collator_utf8_general_ci, &sort_key_container); - testSerializeAndDeserialize(col_string, true, collator_utf8_unicode_ci, &sort_key_container); + testSerializeAndDeserializeForCmp(col_string, collator_utf8_bin, &sort_key_container); + testSerializeAndDeserializeForCmp(col_string, collator_utf8_general_ci, &sort_key_container); + testSerializeAndDeserializeForCmp(col_string, collator_utf8_unicode_ci, &sort_key_container); } CATCH @@ -1033,7 +1064,7 @@ try }; auto col_dec256 = createColumn(std::make_tuple(40, 6), decimal_vals).column; testSerializeAndDeserialize(col_dec256); - testSerializeAndDeserialize(col_dec256, true, nullptr, nullptr); + testSerializeAndDeserializeForCmp(col_dec256, nullptr, nullptr); } CATCH From 4d53325fda0294806671dadda74992a1f6e13feb Mon Sep 17 00:00:00 2001 From: gengliqi Date: Wed, 26 Mar 2025 19:55:40 +0800 Subject: [PATCH 12/84] u Signed-off-by: gengliqi --- dbms/src/Columns/ColumnArray.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dbms/src/Columns/ColumnArray.cpp b/dbms/src/Columns/ColumnArray.cpp index 6b480297f4d..c607b16af5b 100644 --- a/dbms/src/Columns/ColumnArray.cpp +++ b/dbms/src/Columns/ColumnArray.cpp @@ -24,14 +24,13 @@ #include #include #include +#include #include #include #include #include #include // memcpy -#include "Core/Defines.h" - namespace DB { From 772f6c2ab1d02155d70a67bf9611f0f2fb98fb3a Mon Sep 17 00:00:00 2001 From: gengliqi Date: Wed, 26 Mar 2025 21:25:57 +0800 Subject: [PATCH 13/84] fix Signed-off-by: gengliqi --- contrib/tipb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/contrib/tipb b/contrib/tipb index 07c1d4cf432..b4de785594f 160000 --- a/contrib/tipb +++ b/contrib/tipb @@ -1 +1 @@ -Subproject commit 07c1d4cf43236a98d6299afbe864b628b52c0eae +Subproject commit b4de785594f587062406fe484943fc54b8054733 From 933ed11dbe2a5fc6dd68a72ad876d5ce21c53bac Mon Sep 17 00:00:00 2001 From: gengliqi Date: Fri, 4 Apr 2025 14:01:24 
+0800 Subject: [PATCH 14/84] update tests Signed-off-by: gengliqi --- dbms/src/Flash/tests/gtest_join_executor.cpp | 30 +++++++++---------- .../src/Interpreters/JoinV2/HashJoinProbe.cpp | 7 ++++- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/dbms/src/Flash/tests/gtest_join_executor.cpp b/dbms/src/Flash/tests/gtest_join_executor.cpp index 13539fc8d75..c754b69222e 100644 --- a/dbms/src/Flash/tests/gtest_join_executor.cpp +++ b/dbms/src/Flash/tests/gtest_join_executor.cpp @@ -2190,11 +2190,11 @@ try context.addMockTable( {"test_db", "lj_r_table"}, { - {"r1", TiDB::TP::TypeString}, - {"k1", TiDB::TP::TypeLongLong}, - {"k2", TiDB::TP::TypeShort}, - {"r2", TiDB::TP::TypeString}, - {"r3", TiDB::TP::TypeLong}, + {"r1", TiDB::TP::TypeString, false}, + {"k1", TiDB::TP::TypeLongLong, false}, + {"k2", TiDB::TP::TypeShort, false}, + {"r2", TiDB::TP::TypeString, false}, + {"r3", TiDB::TP::TypeLong, false}, }, { toVec("r1", {"apple", "banana", "cat", "dog", "elephant", "frag"}), @@ -2207,12 +2207,12 @@ try context.addMockTable( {"test_db", "lj_l_table"}, { - {"l1", TiDB::TP::TypeString}, - {"k1", TiDB::TP::TypeLongLong}, - {"k2", TiDB::TP::TypeShort}, - {"l2", TiDB::TP::TypeLong}, - {"l3", TiDB::TP::TypeLong}, - {"l4", TiDB::TP::TypeLongLong}, + {"l1", TiDB::TP::TypeString, false}, + {"k1", TiDB::TP::TypeLongLong, true}, + {"k2", TiDB::TP::TypeShort, true}, + {"l2", TiDB::TP::TypeLong, false}, + {"l3", TiDB::TP::TypeLong, true}, + {"l4", TiDB::TP::TypeLongLong, true}, }, { toVec("l1", {"AAA", "BBB", "CCC", "DDD", "EEE", "FFF", "GGG", "HHH", "III", "JJJ", "KKK", "LLL"}), @@ -2286,7 +2286,7 @@ try toNullableVec({1, 2, {}, 3, 5, {}, {}, {}, {}, 4, 3, 5, 4, 1, 2, 3, 5}), toNullableVec({1, 1, {}, 2, 2, {}, {}, {}, {}, 3, 2, 2, 3, 1, 1, 2, 2}), toNullableVec({1, 1, {}, 2, 2, 9, 2, 3, 3, 3, 2, 2, 3, 1, 1, 2, 2}), - toNullableVec( + toVec( {"AAA", "AAA", "BBB", @@ -2304,7 +2304,7 @@ try "KKK", "LLL", "LLL"}), - toNullableVec({1, 1, 2, 3, 3, 4, 5, 6, 6, 6, 7, 7, 8, 9, 9, 
10, 10}), + toVec({1, 1, 2, 3, 3, 4, 5, 6, 6, 6, 7, 7, 8, 9, 9, 10, 10}), toNullableVec({3, 3, 1, 2, 2, 0, 2, 3, 3, {}, 4, 4, 5, 0, 0, 6, 6}), }); WRAP_FOR_JOIN_TEST_END @@ -2338,9 +2338,9 @@ try toNullableVec({1, 2, {}, {}, {}, {}, {}, {}, {}, 3, 4, {}, 3, 5}), toNullableVec({1, 1, {}, {}, {}, {}, {}, {}, {}, 2, 3, {}, 2, 2}), toNullableVec({1, 1, {}, 2, 9, 2, 3, 3, 3, 2, 3, 1, 2, 2}), - toNullableVec( + toVec( {"AAA", "AAA", "BBB", "CCC", "DDD", "EEE", "FFF", "GGG", "HHH", "III", "JJJ", "KKK", "LLL", "LLL"}), - toNullableVec({1, 1, 2, 3, 4, 5, 6, 6, 6, 7, 8, 9, 10, 10}), + toVec({1, 1, 2, 3, 4, 5, 6, 6, 6, 7, 8, 9, 10, 10}), toNullableVec({3, 3, 1, 2, 0, 2, 3, 3, {}, 4, 5, 0, 6, 6}), }); WRAP_FOR_JOIN_FOR_OTHER_CONDITION_TEST_END diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index 6a5244ce470..805d890841c 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -345,9 +345,12 @@ Block JoinProbeBlockHelper::probeImpl(JoinProbeContext & context, JoinProbeWorke MutableColumns added_columns; if constexpr (late_materialization) { - for (auto [column_index, _] : row_layout.raw_key_column_indexes) + for (auto [column_index, is_nullable] : row_layout.raw_key_column_indexes) + { added_columns.emplace_back( wd.result_block.safeGetByPosition(left_columns + column_index).column->assumeMutable()); + RUNTIME_CHECK(added_columns.back()->isColumnNullable() == is_nullable); + } for (size_t i = 0; i < row_layout.other_column_count_for_other_condition; ++i) { size_t column_index = row_layout.other_column_indexes[i].first; @@ -360,6 +363,8 @@ Block JoinProbeBlockHelper::probeImpl(JoinProbeContext & context, JoinProbeWorke added_columns.resize(right_columns); for (size_t i = 0; i < right_columns; ++i) added_columns[i] = wd.result_block.safeGetByPosition(left_columns + i).column->assumeMutable(); + for (auto [column_index, is_nullable] : 
row_layout.raw_key_column_indexes) + RUNTIME_CHECK(added_columns[column_index]->isColumnNullable() == is_nullable); } Stopwatch watch; From b535cd70c49afcb2b452fec999e4af75bc48e65b Mon Sep 17 00:00:00 2001 From: gengliqi Date: Fri, 4 Apr 2025 15:04:02 +0800 Subject: [PATCH 15/84] update tests Signed-off-by: gengliqi --- dbms/src/Flash/tests/gtest_join_executor.cpp | 149 +++++++++++++++--- dbms/src/Interpreters/JoinV2/HashJoin.h | 2 - .../src/Interpreters/JoinV2/HashJoinProbe.cpp | 3 +- 3 files changed, 126 insertions(+), 28 deletions(-) diff --git a/dbms/src/Flash/tests/gtest_join_executor.cpp b/dbms/src/Flash/tests/gtest_join_executor.cpp index c754b69222e..d26e04e8e65 100644 --- a/dbms/src/Flash/tests/gtest_join_executor.cpp +++ b/dbms/src/Flash/tests/gtest_join_executor.cpp @@ -2163,28 +2163,7 @@ try } CATCH -TEST_F(JoinExecutorTestRunner, LeftJoinAggWithOtherCondition) -try -{ - auto request - = context.scan("test_db", "l_table") - .join( - context.scan("test_db", "r_table"), - tipb::JoinType::TypeLeftOuterJoin, - {col("join_c")}, - {}, - {}, - {And(lt(col("l_table.s"), col("r_table.s")), eq(col("l_table.join_c"), col("r_table.join_c")))}, - {}) - .aggregation({Count(lit(static_cast(1)))}, {}) - .build(context); - WRAP_FOR_JOIN_FOR_OTHER_CONDITION_TEST_BEGIN - ASSERT_COLUMNS_EQ_UR(genScalarCountResults(2), executeStreams(request, 2)); - WRAP_FOR_JOIN_FOR_OTHER_CONDITION_TEST_END -} -CATCH - -TEST_F(JoinExecutorTestRunner, LeftOuterJoin) +TEST_F(JoinExecutorTestRunner, LeftOuterJoinWithNullAndNotNullJoinKey) try { context.addMockTable( @@ -2223,7 +2202,7 @@ try toNullableVec("l4", {3, 1, 2, 0, 2, 3, 3, {}, 4, 5, 0, 6}), }); - // No other condition + // No other condition: lj_l_table left join lj_r_table auto request = context.scan("test_db", "lj_l_table") .join( context.scan("test_db", "lj_r_table"), @@ -2309,7 +2288,73 @@ try }); WRAP_FOR_JOIN_TEST_END - // Has other condition + // No other condition: lj_r_table left join lj_l_table + request = 
context.scan("test_db", "lj_r_table") + .join( + context.scan("test_db", "lj_l_table"), + tipb::JoinType::TypeLeftOuterJoin, + {col("k2"), col("k1")}, + {}, + {}, + {}, + {}) + .project( + {"lj_r_table.r1", + "lj_r_table.r2", + "lj_r_table.r3", + "lj_r_table.k2", + "lj_l_table.k2", + "lj_l_table.l1", + "lj_l_table.l2", + "lj_l_table.l4"}) + .build(context); + WRAP_FOR_JOIN_TEST_BEGIN + executeAndAssertColumnsEqual( + request, + { + toVec( + {"apple", + "apple", + "banana", + "banana", + "cat", + "cat", + "cat", + "dog", + "dog", + "dog", + "dog", + "elephant", + "elephant", + "elephant", + "frag"}), + toVec( + {"aaa", + "aaa", + "bbb", + "bbb", + "ccc", + "ccc", + "ccc", + "ddd", + "ddd", + "ddd", + "ddd", + "eee", + "eee", + "eee", + "fff"}), + toVec({1, 1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 6}), + toVec({1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 2, 2, 2, 5}), + toNullableVec({1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 2, 2, 2, {}}), + toNullableVec( + {"AAA", "KKK", "AAA", "KKK", "CCC", "LLL", "III", "FFF", "JJJ", "HHH", "GGG", "CCC", "LLL", "III", {}}), + toNullableVec({1, 9, 1, 9, 3, 10, 7, 6, 8, 6, 6, 3, 10, 7, {}}), + toNullableVec({3, 0, 3, 0, 2, 6, 4, 3, 5, {}, 3, 2, 6, 4, {}}), + }); + WRAP_FOR_JOIN_TEST_END + + // Has other condition: lj_l_table left join lj_r_table request = context.scan("test_db", "lj_l_table") .join( context.scan("test_db", "lj_r_table"), @@ -2344,6 +2389,62 @@ try toNullableVec({3, 3, 1, 2, 0, 2, 3, 3, {}, 4, 5, 0, 6, 6}), }); WRAP_FOR_JOIN_FOR_OTHER_CONDITION_TEST_END + + // Has other condition: lj_r_table left join lj_l_table + request = context.scan("test_db", "lj_r_table") + .join( + context.scan("test_db", "lj_l_table"), + tipb::JoinType::TypeLeftOuterJoin, + {col("k2"), col("k1")}, + {}, + {}, + {gt(col("lj_l_table.l4"), col("lj_r_table.r3"))}, + {}) + .project( + {"lj_r_table.r1", + "lj_r_table.r2", + "lj_r_table.r3", + "lj_r_table.k2", + "lj_l_table.k2", + "lj_l_table.l1", + "lj_l_table.l2", + "lj_l_table.l4"}) + .build(context); + 
WRAP_FOR_JOIN_FOR_OTHER_CONDITION_TEST_BEGIN + executeAndAssertColumnsEqual( + request, + { + toVec({"apple", "banana", "cat", "cat", "dog", "elephant", "frag"}), + toVec({"aaa", "bbb", "ccc", "ccc", "ddd", "eee", "fff"}), + toVec({1, 2, 3, 3, 4, 5, 6}), + toVec({1, 1, 2, 2, 3, 2, 5}), + toNullableVec({1, 1, 2, 2, 3, 2, {}}), + toNullableVec({"AAA", "AAA", "LLL", "III", "JJJ", "LLL", {}}), + toNullableVec({1, 1, 10, 7, 8, 10, {}}), + toNullableVec({3, 3, 6, 4, 5, 6, {}}), + }); + WRAP_FOR_JOIN_FOR_OTHER_CONDITION_TEST_END +} +CATCH + +TEST_F(JoinExecutorTestRunner, LeftJoinAggWithOtherCondition) +try +{ + auto request + = context.scan("test_db", "l_table") + .join( + context.scan("test_db", "r_table"), + tipb::JoinType::TypeLeftOuterJoin, + {col("join_c")}, + {}, + {}, + {And(lt(col("l_table.s"), col("r_table.s")), eq(col("l_table.join_c"), col("r_table.join_c")))}, + {}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + WRAP_FOR_JOIN_FOR_OTHER_CONDITION_TEST_BEGIN + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(2), executeStreams(request, 2)); + WRAP_FOR_JOIN_FOR_OTHER_CONDITION_TEST_END } CATCH diff --git a/dbms/src/Interpreters/JoinV2/HashJoin.h b/dbms/src/Interpreters/JoinV2/HashJoin.h index d00a4e9c661..ee3eb0d239a 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoin.h +++ b/dbms/src/Interpreters/JoinV2/HashJoin.h @@ -27,8 +27,6 @@ #include #include -#include - namespace DB { diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index 805d890841c..b9f90df99cc 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -364,7 +364,7 @@ Block JoinProbeBlockHelper::probeImpl(JoinProbeContext & context, JoinProbeWorke for (size_t i = 0; i < right_columns; ++i) added_columns[i] = wd.result_block.safeGetByPosition(left_columns + i).column->assumeMutable(); for (auto [column_index, is_nullable] : row_layout.raw_key_column_indexes) - 
RUNTIME_CHECK(added_columns[column_index]->isColumnNullable() == is_nullable); + RUNTIME_CHECK(added_columns.at(column_index)->isColumnNullable() == is_nullable); } Stopwatch watch; @@ -381,7 +381,6 @@ Block JoinProbeBlockHelper::probeImpl(JoinProbeContext & context, JoinProbeWorke context, wd, added_columns); - wd.probe_hash_table_time += watch.elapsedFromLastTime(); if constexpr (late_materialization) From 1945bf02957285391343510613c5b6fb6695632e Mon Sep 17 00:00:00 2001 From: gengliqi Date: Mon, 7 Apr 2025 16:02:19 +0800 Subject: [PATCH 16/84] address comments Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp | 6 +++--- .../Interpreters/NullAwareSemiJoinHelper.cpp | 6 +++--- dbms/src/Interpreters/SemiJoinHelper.cpp | 18 +++++------------- 3 files changed, 11 insertions(+), 19 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index b9f90df99cc..1f793751d53 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -383,6 +383,9 @@ Block JoinProbeBlockHelper::probeImpl(JoinProbeContext & context, JoinProbeWorke added_columns); wd.probe_hash_table_time += watch.elapsedFromLastTime(); + if (wd.selective_offsets.empty()) + return join->output_block_after_finalize; + if constexpr (late_materialization) { size_t idx = 0; @@ -400,9 +403,6 @@ Block JoinProbeBlockHelper::probeImpl(JoinProbeContext & context, JoinProbeWorke wd.result_block.safeGetByPosition(left_columns + i).column = std::move(added_columns[i]); } - if (wd.selective_offsets.empty()) - return join->output_block_after_finalize; - if constexpr (has_other_condition) { // Always using late materialization for left side diff --git a/dbms/src/Interpreters/NullAwareSemiJoinHelper.cpp b/dbms/src/Interpreters/NullAwareSemiJoinHelper.cpp index daf5172fa71..a53b5e34df9 100644 --- a/dbms/src/Interpreters/NullAwareSemiJoinHelper.cpp +++ 
b/dbms/src/Interpreters/NullAwareSemiJoinHelper.cpp @@ -292,13 +292,13 @@ Block NASemiJoinHelper::genJoinResult(const NameSet & ou { // If the result is true, this row should be kept. // Otherwise, this row should be filtered. - (*filter)[i] = result == SemiJoinResultType::TRUE_VALUE ? 1 : 0; + (*filter)[i] = result == SemiJoinResultType::TRUE_VALUE; rows_for_anti += (*filter)[i]; } else { - Int8 res = result == SemiJoinResultType::TRUE_VALUE ? 1 : 0; - UInt8 is_null = result == SemiJoinResultType::NULL_VALUE ? 1 : 0; + Int8 res = result == SemiJoinResultType::TRUE_VALUE; + UInt8 is_null = result == SemiJoinResultType::NULL_VALUE; left_semi_column_data->push_back(res); left_semi_null_map->push_back(is_null); } diff --git a/dbms/src/Interpreters/SemiJoinHelper.cpp b/dbms/src/Interpreters/SemiJoinHelper.cpp index 918996a7efe..0727e9d9664 100644 --- a/dbms/src/Interpreters/SemiJoinHelper.cpp +++ b/dbms/src/Interpreters/SemiJoinHelper.cpp @@ -429,17 +429,9 @@ Block SemiJoinHelper::genJoinResult(const NameSet & outp auto result = join_result[i].getResult(); if constexpr (KIND == ASTTableJoin::Kind::Semi || KIND == ASTTableJoin::Kind::Anti) { - if (isTrueSemiJoinResult(result)) - { - // If the result is true, this row should be kept. - (*filter)[i] = 1; - ++rows_for_semi_anti; - } - else - { - // If the result is null or false, this row should be filtered. - (*filter)[i] = 0; - } + // If the result is true, this row should be kept. + (*filter)[i] = isTrueSemiJoinResult(result); + rows_for_semi_anti += (*filter)[i]; } else { @@ -449,8 +441,8 @@ Block SemiJoinHelper::genJoinResult(const NameSet & outp } else { - Int8 res = result == SemiJoinResultType::TRUE_VALUE ? 1 : 0; - UInt8 is_null = result == SemiJoinResultType::NULL_VALUE ? 
1 : 0; + Int8 res = result == SemiJoinResultType::TRUE_VALUE; + UInt8 is_null = result == SemiJoinResultType::NULL_VALUE; left_semi_column_data->push_back(res); left_semi_null_map->push_back(is_null); } From 270363f7b82b5bed3b0f8321def6b32780607928 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Tue, 8 Apr 2025 16:05:57 +0800 Subject: [PATCH 17/84] address comments Signed-off-by: gengliqi --- dbms/src/Columns/ColumnAggregateFunction.h | 5 +++++ dbms/src/Columns/ColumnArray.cpp | 15 ++++++++++----- dbms/src/Columns/ColumnArray.h | 2 ++ dbms/src/Columns/ColumnConst.h | 5 +++++ dbms/src/Columns/ColumnDecimal.h | 2 ++ dbms/src/Columns/ColumnFixedString.h | 2 ++ dbms/src/Columns/ColumnFunction.h | 5 +++++ dbms/src/Columns/ColumnNullable.cpp | 5 +++++ dbms/src/Columns/ColumnNullable.h | 2 ++ dbms/src/Columns/ColumnString.h | 2 ++ dbms/src/Columns/ColumnTuple.h | 8 ++++++++ dbms/src/Columns/ColumnVector.h | 2 ++ dbms/src/Columns/IColumn.h | 4 ++++ dbms/src/Columns/IColumnDummy.h | 5 +++++ dbms/src/Interpreters/JoinV2/HashJoin.cpp | 2 +- dbms/src/Interpreters/JoinV2/HashJoinBuild.cpp | 11 ++--------- 16 files changed, 62 insertions(+), 15 deletions(-) diff --git a/dbms/src/Columns/ColumnAggregateFunction.h b/dbms/src/Columns/ColumnAggregateFunction.h index a3a688db282..873f773df96 100644 --- a/dbms/src/Columns/ColumnAggregateFunction.h +++ b/dbms/src/Columns/ColumnAggregateFunction.h @@ -167,6 +167,11 @@ class ColumnAggregateFunction final : public COWPtrHelper & /* byte_size */) const override { throw Exception("Method countSerializeByteSize is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); diff --git a/dbms/src/Columns/ColumnArray.cpp b/dbms/src/Columns/ColumnArray.cpp index c607b16af5b..66a46035472 100644 --- a/dbms/src/Columns/ColumnArray.cpp +++ b/dbms/src/Columns/ColumnArray.cpp @@ -219,12 +219,9 @@ const char * ColumnArray::deserializeAndInsertFromArena(const char * pos, const return pos; } -void ColumnArray::countSerializeByteSizeForCmp( - 
PaddedPODArray & byte_size, - const NullMap * nullmap, - const TiDB::TiDBCollatorPtr & collator) const +size_t ColumnArray::serializeByteSize() const { - countSerializeByteSizeImpl(byte_size, nullmap, collator); + return getData().serializeByteSize() + getOffsets().size() * sizeof(UInt32); } void ColumnArray::countSerializeByteSize(PaddedPODArray & byte_size) const @@ -232,6 +229,14 @@ void ColumnArray::countSerializeByteSize(PaddedPODArray & byte_size) con countSerializeByteSizeImpl(byte_size, nullptr, nullptr); } +void ColumnArray::countSerializeByteSizeForCmp( + PaddedPODArray & byte_size, + const NullMap * nullmap, + const TiDB::TiDBCollatorPtr & collator) const +{ + countSerializeByteSizeImpl(byte_size, nullmap, collator); +} + template void ColumnArray::countSerializeByteSizeImpl( PaddedPODArray & byte_size, diff --git a/dbms/src/Columns/ColumnArray.h b/dbms/src/Columns/ColumnArray.h index ffdb00d302c..1453b2c2d0b 100644 --- a/dbms/src/Columns/ColumnArray.h +++ b/dbms/src/Columns/ColumnArray.h @@ -82,6 +82,8 @@ class ColumnArray final : public COWPtrHelper String &) const override; const char * deserializeAndInsertFromArena(const char * pos, const TiDB::TiDBCollatorPtr &) override; + size_t serializeByteSize() const override; + void countSerializeByteSize(PaddedPODArray & byte_size) const override; void countSerializeByteSizeForCmp( PaddedPODArray & byte_size, diff --git a/dbms/src/Columns/ColumnConst.h b/dbms/src/Columns/ColumnConst.h index 1baed2636f6..7548e07b2a8 100644 --- a/dbms/src/Columns/ColumnConst.h +++ b/dbms/src/Columns/ColumnConst.h @@ -109,6 +109,11 @@ class ColumnConst final : public COWPtrHelper return res; } + size_t serializeByteSize() const override + { + throw Exception("Method serializeByteSize is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); + } + void countSerializeByteSize(PaddedPODArray & /* byte_size */) const override { throw Exception("Method countSerializeByteSize is not supported for " + getName(), 
ErrorCodes::NOT_IMPLEMENTED); diff --git a/dbms/src/Columns/ColumnDecimal.h b/dbms/src/Columns/ColumnDecimal.h index 6a1394368ba..43f1f869a3a 100644 --- a/dbms/src/Columns/ColumnDecimal.h +++ b/dbms/src/Columns/ColumnDecimal.h @@ -152,6 +152,8 @@ class ColumnDecimal final : public COWPtrHelper & byte_size) const override; void countSerializeByteSizeForCmp( PaddedPODArray & byte_size, diff --git a/dbms/src/Columns/ColumnFixedString.h b/dbms/src/Columns/ColumnFixedString.h index 5b28a7097ef..7193a4f3b6f 100644 --- a/dbms/src/Columns/ColumnFixedString.h +++ b/dbms/src/Columns/ColumnFixedString.h @@ -105,6 +105,8 @@ class ColumnFixedString final : public COWPtrHelper const char * deserializeAndInsertFromArena(const char * pos, const TiDB::TiDBCollatorPtr &) override; + size_t serializeByteSize() const override { return chars.size(); } + void countSerializeByteSize(PaddedPODArray & byte_size) const override; void countSerializeByteSizeForCmp( PaddedPODArray & byte_size, diff --git a/dbms/src/Columns/ColumnFunction.h b/dbms/src/Columns/ColumnFunction.h index 9725a2d43e6..229e0ce9ce6 100644 --- a/dbms/src/Columns/ColumnFunction.h +++ b/dbms/src/Columns/ColumnFunction.h @@ -120,6 +120,11 @@ class ColumnFunction final : public COWPtrHelper throw Exception("Cannot deserialize to " + getName(), ErrorCodes::NOT_IMPLEMENTED); } + size_t serializeByteSize() const override + { + throw Exception("Method serializeByteSize is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); + } + void countSerializeByteSize(PaddedPODArray & /* byte_size */) const override { throw Exception("Method countSerializeByteSize is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); diff --git a/dbms/src/Columns/ColumnNullable.cpp b/dbms/src/Columns/ColumnNullable.cpp index 2a5403a1278..8a5ceaa1b0c 100644 --- a/dbms/src/Columns/ColumnNullable.cpp +++ b/dbms/src/Columns/ColumnNullable.cpp @@ -282,6 +282,11 @@ const char * ColumnNullable::deserializeAndInsertFromArena(const char * pos, 
con return pos; } +size_t ColumnNullable::serializeByteSize() const +{ + return getNestedColumn().serializeByteSize() + getNullMapColumn().serializeByteSize(); +} + void ColumnNullable::countSerializeByteSize(PaddedPODArray & byte_size) const { getNullMapColumn().countSerializeByteSize(byte_size); diff --git a/dbms/src/Columns/ColumnNullable.h b/dbms/src/Columns/ColumnNullable.h index a8389ed4efc..ed1fd3f8439 100644 --- a/dbms/src/Columns/ColumnNullable.h +++ b/dbms/src/Columns/ColumnNullable.h @@ -80,6 +80,8 @@ class ColumnNullable final : public COWPtrHelper String &) const override; const char * deserializeAndInsertFromArena(const char * pos, const TiDB::TiDBCollatorPtr &) override; + size_t serializeByteSize() const override; + void countSerializeByteSize(PaddedPODArray & byte_size) const override; void countSerializeByteSizeForCmp( PaddedPODArray & byte_size, diff --git a/dbms/src/Columns/ColumnString.h b/dbms/src/Columns/ColumnString.h index 5e7b5b00a10..a1f481f847a 100644 --- a/dbms/src/Columns/ColumnString.h +++ b/dbms/src/Columns/ColumnString.h @@ -238,6 +238,8 @@ class ColumnString final : public COWPtrHelper return pos + string_size; } + size_t serializeByteSize() const override { return chars.size() + offsets.size() * sizeof(UInt32); } + void countSerializeByteSize(PaddedPODArray & byte_size) const override; void countSerializeByteSizeForCmp( PaddedPODArray & byte_size, diff --git a/dbms/src/Columns/ColumnTuple.h b/dbms/src/Columns/ColumnTuple.h index 315d96b23cf..6ae50cbd283 100644 --- a/dbms/src/Columns/ColumnTuple.h +++ b/dbms/src/Columns/ColumnTuple.h @@ -97,6 +97,14 @@ class ColumnTuple final : public COWPtrHelper String &) const override; const char * deserializeAndInsertFromArena(const char * pos, const TiDB::TiDBCollatorPtr &) override; + size_t serializeByteSize() const override + { + size_t res = 0; + for (const auto & column : columns) + res += column->serializeByteSize(); + return res; + } + void countSerializeByteSize(PaddedPODArray & 
byte_size) const override { for (const auto & column : columns) diff --git a/dbms/src/Columns/ColumnVector.h b/dbms/src/Columns/ColumnVector.h index 1c418f7b75f..067b696cdfe 100644 --- a/dbms/src/Columns/ColumnVector.h +++ b/dbms/src/Columns/ColumnVector.h @@ -321,6 +321,8 @@ class ColumnVector final : public COWPtrHelper & byte_size) const override; void countSerializeByteSizeForCmp( PaddedPODArray & byte_size, diff --git a/dbms/src/Columns/IColumn.h b/dbms/src/Columns/IColumn.h index a302bc28606..6ac8bfe4f75 100644 --- a/dbms/src/Columns/IColumn.h +++ b/dbms/src/Columns/IColumn.h @@ -253,6 +253,10 @@ class IColumn : public COWPtr virtual const char * deserializeAndInsertFromArena(const char * pos, const TiDB::TiDBCollatorPtr & collator) = 0; const char * deserializeAndInsertFromArena(const char * pos) { return deserializeAndInsertFromArena(pos, nullptr); } + /// Size of serialized column data in memory (may be approximate). + /// The main difference between `serializeByteSize` and `byteSize` is that `serializeByteSize` uses UInt32 + /// instead of size_t to represent the length of each element. + virtual size_t serializeByteSize() const = 0; /// Count the serialize byte size and added to the byte_size. /// The byte_size.size() must be equal to the column size. 
virtual void countSerializeByteSize(PaddedPODArray & /* byte_size */) const = 0; diff --git a/dbms/src/Columns/IColumnDummy.h b/dbms/src/Columns/IColumnDummy.h index 7900184c03e..86b3e6c00bd 100644 --- a/dbms/src/Columns/IColumnDummy.h +++ b/dbms/src/Columns/IColumnDummy.h @@ -88,6 +88,11 @@ class IColumnDummy : public IColumn return pos; } + size_t serializeByteSize() const override + { + throw Exception("Method serializeByteSize is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); + } + void countSerializeByteSize(PaddedPODArray & /* byte_size */) const override { throw Exception("Method countSerializeByteSize is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); diff --git a/dbms/src/Interpreters/JoinV2/HashJoin.cpp b/dbms/src/Interpreters/JoinV2/HashJoin.cpp index d4539857e96..22698e78aeb 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoin.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoin.cpp @@ -452,7 +452,7 @@ void HashJoin::workAfterBuildRowFinish() settings.probe_enable_prefetch_threshold, enable_tagged_pointer); - const size_t lm_size_threshold = 32; + const size_t lm_size_threshold = 16; bool late_materialization = false; size_t avg_lm_row_size = 0; if (has_other_condition diff --git a/dbms/src/Interpreters/JoinV2/HashJoinBuild.cpp b/dbms/src/Interpreters/JoinV2/HashJoinBuild.cpp index bfbb240314b..337b24e0500 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinBuild.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinBuild.cpp @@ -62,14 +62,7 @@ void NO_INLINE insertBlockToRowContainersTypeImpl( { size_t index = row_layout.other_column_indexes[i].first; const auto & column = block.getByPosition(index).column; - if (const auto * column_string = typeid_cast(column.get())) - { - wd.lm_row_size += column_string->getChars().size() + sizeof(UInt32) * column_string->size(); - } - else - { - wd.lm_row_size += column->byteSize(); - } + wd.lm_row_size += column->serializeByteSize(); } } @@ -142,7 +135,7 @@ void NO_INLINE 
insertBlockToRowContainersTypeImpl( container.data.resize(wd.partition_row_sizes[i], CPU_CACHE_LINE_SIZE); wd.enable_tagged_pointer &= isRowPtrTagZero(container.data.data()); wd.enable_tagged_pointer &= isRowPtrTagZero(container.data.data() + wd.partition_row_sizes[i]); - assert((reinterpret_cast(container.data.data()) & (CPU_CACHE_LINE_SIZE - 1)) == 0); + RUNTIME_CHECK((reinterpret_cast(container.data.data()) & (CPU_CACHE_LINE_SIZE - 1)) == 0); wd.all_size += wd.partition_row_sizes[i]; container.offsets.reserve(wd.partition_row_count[i]); From b70e60baec61cfeb7396c210e34b26333835bc8c Mon Sep 17 00:00:00 2001 From: gengliqi Date: Wed, 9 Apr 2025 16:36:15 +0800 Subject: [PATCH 18/84] add unit tests for IColumn::serializeByteSize Signed-off-by: gengliqi --- dbms/src/Columns/tests/gtest_column_misc.cpp | 43 +++++++++++++++++++ .../JoinV2/HashJoinPointerTable.cpp | 2 +- 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/dbms/src/Columns/tests/gtest_column_misc.cpp b/dbms/src/Columns/tests/gtest_column_misc.cpp index d6961b04326..37aa222c8b6 100644 --- a/dbms/src/Columns/tests/gtest_column_misc.cpp +++ b/dbms/src/Columns/tests/gtest_column_misc.cpp @@ -16,6 +16,8 @@ #include #include +#include "common/types.h" + namespace DB { namespace tests @@ -68,6 +70,47 @@ try } CATCH +TEST_F(TestColumnMisc, TestSerializeByteSize) +try +{ + auto col_vector = createColumn({1, 2, 3}).column; + ASSERT_EQ(col_vector->serializeByteSize(), sizeof(UInt32) * 3); + auto col_decimal = createColumn(std::make_tuple(10, 3), {"1234567.333", "23333.99"}).column; + ASSERT_EQ(col_decimal->serializeByteSize(), sizeof(Decimal128) * 2); + auto col_string = createColumn({"abc", "def", "g", "hij", "", "mn"}).column; + ASSERT_EQ(col_string->serializeByteSize(), sizeof(UInt32) * 6 + 18); + + auto nullable_col_vector = toNullableVec({1, 2, 3, 4, {}}).column; + ASSERT_EQ(nullable_col_vector->serializeByteSize(), sizeof(UInt8) * 5 + sizeof(UInt64) * 5); + auto nullable_col_decimal = 
createNullableColumn( + std::make_tuple(65, 30), + { + "123456789012345678901234567890", + "100.1111111111", + "-11111111111111111111", + "0.1111111111111", + "0.1111111111111", + "2.2222222222", + }, + {1, 0, 1, 1, 0, 1}) + .column; + ASSERT_EQ(nullable_col_decimal->serializeByteSize(), sizeof(UInt8) * 6 + sizeof(Decimal256) * 6); + auto nullable_col_string = toNullableVec({"123456789", {}, "1"}).column; + ASSERT_EQ(nullable_col_string->serializeByteSize(), sizeof(UInt8) * 3 + sizeof(UInt32) * 3 + 13); + + auto col_array = createColumn( + std::make_tuple(std::make_shared()), // + {Array{}, Array{1.0, 2.0}, Array{1.0, 2.0, 3.0}}) + .column; + ASSERT_EQ(col_array->serializeByteSize(), sizeof(UInt32) * 3 + sizeof(Float32) * 5); + + ColumnPtr col_fixed_string = ColumnFixedString::create(3); + col_fixed_string->assumeMutable()->insertData("123", 2); + col_fixed_string->assumeMutable()->insertData("12", 2); + col_fixed_string->assumeMutable()->insertData("1", 1); + ASSERT_EQ(col_fixed_string->serializeByteSize(), 3 * 3); +} +CATCH } // namespace tests } // namespace DB diff --git a/dbms/src/Interpreters/JoinV2/HashJoinPointerTable.cpp b/dbms/src/Interpreters/JoinV2/HashJoinPointerTable.cpp index ac690087eb6..f57751a954e 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinPointerTable.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinPointerTable.cpp @@ -49,7 +49,7 @@ void HashJoinPointerTable::init( } RUNTIME_CHECK(isPowerOfTwo(pointer_table_size) && pointer_table_size > 0); - pointer_table_size_degree = __builtin_ctzll(pointer_table_size); + pointer_table_size_degree = std::countr_zero(pointer_table_size); RUNTIME_CHECK((1ULL << pointer_table_size_degree) == pointer_table_size); enable_probe_prefetch = pointer_table_size >= probe_prefetch_threshold; From b1738d5a761bb2c8f0e648c35ac1aa62c823d82f Mon Sep 17 00:00:00 2001 From: gengliqi Date: Wed, 9 Apr 2025 16:39:41 +0800 Subject: [PATCH 19/84] u Signed-off-by: gengliqi --- dbms/src/Columns/tests/gtest_column_misc.cpp | 7 
+++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/dbms/src/Columns/tests/gtest_column_misc.cpp b/dbms/src/Columns/tests/gtest_column_misc.cpp index 37aa222c8b6..4d19c864baf 100644 --- a/dbms/src/Columns/tests/gtest_column_misc.cpp +++ b/dbms/src/Columns/tests/gtest_column_misc.cpp @@ -15,8 +15,7 @@ #include #include #include - -#include "common/types.h" +#include namespace DB { @@ -60,7 +59,7 @@ try auto col_string = createColumn({"sdafyuwer123"}).column; testCloneFullColumn(col_string); auto col_array = createColumn( - std::make_tuple(std::make_shared()), // + std::make_tuple(std::make_shared()), {Array{}, Array{1.0, 2.0}, Array{1.0, 2.0, 3.0}}) .column; testCloneFullColumn(col_array); @@ -99,7 +98,7 @@ try ASSERT_EQ(nullable_col_string->serializeByteSize(), sizeof(UInt8) * 3 + sizeof(UInt32) * 3 + 13); auto col_array = createColumn( - std::make_tuple(std::make_shared()), // + std::make_tuple(std::make_shared()), {Array{}, Array{1.0, 2.0}, Array{1.0, 2.0, 3.0}}) .column; ASSERT_EQ(col_array->serializeByteSize(), sizeof(UInt32) * 3 + sizeof(Float32) * 5); From 43b8e0a299554b6f13a2df64974aad3dd35945e5 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Thu, 10 Apr 2025 16:08:28 +0800 Subject: [PATCH 20/84] u Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/HashJoinProbe.h | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h index 3d82d1b266c..3ca77cf06d3 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h @@ -80,27 +80,29 @@ struct alignas(CPU_CACHE_LINE_SIZE) JoinProbeWorkerData IColumn::Offsets selective_offsets; // For left outer join with no other condition IColumn::Offsets not_matched_selective_offsets; - RowPtrs row_ptrs_for_lm; RowPtrs insert_batch; - size_t probe_handle_rows = 0; - size_t probe_time = 0; - size_t probe_hash_table_time = 0; - size_t 
replicate_time = 0; - size_t other_condition_time = 0; - size_t collision = 0; - /// For other condition ColumnVector::Container filter; IColumn::Offsets filter_offsets; IColumn::Offsets filter_selective_offsets; + /// For late materialization + RowPtrs row_ptrs_for_lm; RowPtrs filter_row_ptrs_for_lm; /// Schema: HashJoin::all_sample_block_pruned Block result_block; /// Schema: HashJoin::output_block_after_finalize Block result_block_for_other_condition; + + /// Metrics + size_t probe_handle_rows = 0; + size_t probe_time = 0; + size_t probe_hash_table_time = 0; + size_t replicate_time = 0; + size_t other_condition_time = 0; + size_t collision = 0; }; /// The implemtation of prefetching in join probe process is inspired by a paper named From 62ca1fbd4a84c786c52dc4874d390190b175605b Mon Sep 17 00:00:00 2001 From: gengliqi Date: Thu, 10 Apr 2025 19:51:37 +0800 Subject: [PATCH 21/84] support (left outer) (anti) semi join with no other condition Signed-off-by: gengliqi --- dbms/src/Common/ColumnNTAlignBuffer.h | 2 +- dbms/src/Core/Types.h | 2 +- .../Flash/Planner/Plans/PhysicalJoinV2.cpp | 23 +- dbms/src/Flash/tests/gtest_join_executor.cpp | 125 +++++--- dbms/src/Interpreters/JoinV2/HashJoin.cpp | 32 +- dbms/src/Interpreters/JoinV2/HashJoin.h | 2 + .../src/Interpreters/JoinV2/HashJoinProbe.cpp | 284 +++++++++++++++--- dbms/src/Interpreters/JoinV2/HashJoinProbe.h | 4 +- 8 files changed, 372 insertions(+), 102 deletions(-) diff --git a/dbms/src/Common/ColumnNTAlignBuffer.h b/dbms/src/Common/ColumnNTAlignBuffer.h index d7c2dacdd3c..74d35fa0caa 100644 --- a/dbms/src/Common/ColumnNTAlignBuffer.h +++ b/dbms/src/Common/ColumnNTAlignBuffer.h @@ -28,7 +28,7 @@ static constexpr size_t VECTOR_SIZE_AVX2 = sizeof(__m256i); static constexpr size_t FULL_VECTOR_SIZE_AVX2 = 2 * VECTOR_SIZE_AVX2; static_assert(FULL_VECTOR_SIZE_AVX2 == 64); -/// NT stand for non-temporal. +/// NT stands for non-temporal. 
union alignas(FULL_VECTOR_SIZE_AVX2) NTAlignBufferAVX2 { char data[FULL_VECTOR_SIZE_AVX2]{}; diff --git a/dbms/src/Core/Types.h b/dbms/src/Core/Types.h index 6cc71cbd9cc..6ebe073d9fc 100644 --- a/dbms/src/Core/Types.h +++ b/dbms/src/Core/Types.h @@ -225,7 +225,7 @@ struct TypeId static constexpr const TypeIndex value = TypeIndex::Float64; }; -/// Avoid to use `std::vector` using BoolVec = std::vector; } // namespace DB diff --git a/dbms/src/Flash/Planner/Plans/PhysicalJoinV2.cpp b/dbms/src/Flash/Planner/Plans/PhysicalJoinV2.cpp index fbe5a1b305a..2359be8da8c 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalJoinV2.cpp +++ b/dbms/src/Flash/Planner/Plans/PhysicalJoinV2.cpp @@ -210,15 +210,20 @@ bool PhysicalJoinV2::isSupported(const tipb::Join & join) { case Inner: case LeftOuter: - //case Semi: - //case Anti: - //case RightOuter: - //case RightSemi: - //case RightAnti: - { - if (!tiflash_join.getBuildJoinKeys().empty()) - return true; - } + if (!tiflash_join.getBuildJoinKeys().empty()) + return true; + break; + case Semi: + case Anti: + case LeftOuterSemi: + case LeftOuterAnti: + if (!tiflash_join.getBuildJoinKeys().empty() && join.other_conditions_size() == 0 + && join.other_eq_conditions_from_in_size() == 0) + return true; + break; + //case RightOuter: + //case RightSemi: + //case RightAnti: default: } return false; diff --git a/dbms/src/Flash/tests/gtest_join_executor.cpp b/dbms/src/Flash/tests/gtest_join_executor.cpp index d26e04e8e65..aedbdca13c0 100644 --- a/dbms/src/Flash/tests/gtest_join_executor.cpp +++ b/dbms/src/Flash/tests/gtest_join_executor.cpp @@ -4300,6 +4300,7 @@ try context.addMockTable("semi", "t", {{"a", TiDB::TP::TypeLong, false}}, left); context.addMockTable("semi", "s", {{"a", TiDB::TP::TypeLong, false}}, right); + WRAP_FOR_JOIN_TEST_BEGIN for (const auto type : {JoinType::TypeLeftOuterSemiJoin, JoinType::TypeAntiLeftOuterSemiJoin, @@ -4308,19 +4309,27 @@ try { auto reference = genSemiJoinResult(type, left, res); auto request = 
context.scan("semi", "t").join(context.scan("semi", "s"), type, {col("a")}).build(context); - for (auto need_force_semi_join_time_exceed : semi_join_time_exceed) + if (cfg.enable_join_v2) { - if (need_force_semi_join_time_exceed) - { - FailPointHelper::enableFailPoint(FailPoints::force_semi_join_time_exceed); - } - else + executeAndAssertColumnsEqual(request, reference); + } + else + { + for (auto need_force_semi_join_time_exceed : semi_join_time_exceed) { - FailPointHelper::disableFailPoint(FailPoints::force_semi_join_time_exceed); + if (need_force_semi_join_time_exceed) + { + FailPointHelper::enableFailPoint(FailPoints::force_semi_join_time_exceed); + } + else + { + FailPointHelper::disableFailPoint(FailPoints::force_semi_join_time_exceed); + } + executeAndAssertColumnsEqual(request, reference); } - executeAndAssertColumnsEqual(request, reference); } } + WRAP_FOR_JOIN_TEST_END } /// One join key(t.a = s.a) + other condition(t.c < s.c). @@ -4352,6 +4361,7 @@ try context.addMockTable("semi", "t", {{"a", TiDB::TP::TypeLong, false}, {"c", TiDB::TP::TypeLong}}, left); context.addMockTable("semi", "s", {{"a", TiDB::TP::TypeLong, false}, {"c", TiDB::TP::TypeLong}}, right); + WRAP_FOR_JOIN_FOR_OTHER_CONDITION_TEST_BEGIN for (const auto type : {JoinType::TypeLeftOuterSemiJoin, JoinType::TypeAntiLeftOuterSemiJoin, @@ -4363,19 +4373,27 @@ try = context.scan("semi", "t") .join(context.scan("semi", "s"), type, {col("a")}, {}, {}, {lt(col("t.c"), col("s.c"))}, {}) .build(context); - for (auto need_force_semi_join_time_exceed : semi_join_time_exceed) + if (cfg.enable_join_v2) { - if (need_force_semi_join_time_exceed) - { - FailPointHelper::enableFailPoint(FailPoints::force_semi_join_time_exceed); - } - else + executeAndAssertColumnsEqual(request, reference); + } + else + { + for (auto need_force_semi_join_time_exceed : semi_join_time_exceed) { - FailPointHelper::disableFailPoint(FailPoints::force_semi_join_time_exceed); + if (need_force_semi_join_time_exceed) + { + 
FailPointHelper::enableFailPoint(FailPoints::force_semi_join_time_exceed); + } + else + { + FailPointHelper::disableFailPoint(FailPoints::force_semi_join_time_exceed); + } + executeAndAssertColumnsEqual(request, reference); } - executeAndAssertColumnsEqual(request, reference); } } + WRAP_FOR_JOIN_FOR_OTHER_CONDITION_TEST_END } /// Two join keys(t.a = s.a and t.b = s.b) + no other condition. @@ -4408,6 +4426,7 @@ try context.addMockTable("semi", "t", {{"a", TiDB::TP::TypeLong, false}, {"b", TiDB::TP::TypeLong, false}}, left); context.addMockTable("semi", "s", {{"a", TiDB::TP::TypeLong, false}, {"b", TiDB::TP::TypeLong, false}}, right); + WRAP_FOR_JOIN_TEST_BEGIN for (const auto type : {JoinType::TypeLeftOuterSemiJoin, JoinType::TypeAntiLeftOuterSemiJoin, @@ -4418,19 +4437,27 @@ try auto request = context.scan("semi", "t") .join(context.scan("semi", "s"), type, {col("a"), col("b")}, {}) .build(context); - for (auto need_force_semi_join_time_exceed : semi_join_time_exceed) + if (cfg.enable_join_v2) { - if (need_force_semi_join_time_exceed) - { - FailPointHelper::enableFailPoint(FailPoints::force_semi_join_time_exceed); - } - else + executeAndAssertColumnsEqual(request, reference); + } + else + { + for (auto need_force_semi_join_time_exceed : semi_join_time_exceed) { - FailPointHelper::disableFailPoint(FailPoints::force_semi_join_time_exceed); + if (need_force_semi_join_time_exceed) + { + FailPointHelper::enableFailPoint(FailPoints::force_semi_join_time_exceed); + } + else + { + FailPointHelper::disableFailPoint(FailPoints::force_semi_join_time_exceed); + } + executeAndAssertColumnsEqual(request, reference); } - executeAndAssertColumnsEqual(request, reference); } } + WRAP_FOR_JOIN_TEST_END } /// Two join keys(t.a = s.a and t.b = s.b) + other condition(t.c < s.c). 
@@ -4494,6 +4521,7 @@ try {{"a", TiDB::TP::TypeLong, false}, {"b", TiDB::TP::TypeLong, false}, {"c", TiDB::TP::TypeLong}}, right); + WRAP_FOR_JOIN_FOR_OTHER_CONDITION_TEST_BEGIN for (const auto type : {JoinType::TypeLeftOuterSemiJoin, JoinType::TypeAntiLeftOuterSemiJoin, @@ -4511,19 +4539,27 @@ try {lt(col("t.c"), col("s.c"))}, {}) .build(context); - for (auto need_force_semi_join_time_exceed : semi_join_time_exceed) + if (cfg.enable_join_v2) { - if (need_force_semi_join_time_exceed) - { - FailPointHelper::enableFailPoint(FailPoints::force_semi_join_time_exceed); - } - else + executeAndAssertColumnsEqual(request, reference); + } + else + { + for (auto need_force_semi_join_time_exceed : semi_join_time_exceed) { - FailPointHelper::disableFailPoint(FailPoints::force_semi_join_time_exceed); + if (need_force_semi_join_time_exceed) + { + FailPointHelper::enableFailPoint(FailPoints::force_semi_join_time_exceed); + } + else + { + FailPointHelper::disableFailPoint(FailPoints::force_semi_join_time_exceed); + } + executeAndAssertColumnsEqual(request, reference); } - executeAndAssertColumnsEqual(request, reference); } } + WRAP_FOR_JOIN_FOR_OTHER_CONDITION_TEST_END } /// Two join keys(t.a = s.a and t.b = s.b) + no other condition + collation(UTF8MB4_UNICODE_CI). 
@@ -4544,6 +4580,7 @@ try context .addMockTable("semi", "s", {{"a", TiDB::TP::TypeString, false}, {"b", TiDB::TP::TypeString, false}}, right); + WRAP_FOR_JOIN_TEST_BEGIN for (const auto type : {JoinType::TypeLeftOuterSemiJoin, JoinType::TypeAntiLeftOuterSemiJoin, @@ -4554,19 +4591,27 @@ try auto request = context.scan("semi", "t") .join(context.scan("semi", "s"), type, {col("a"), col("b")}, {}) .build(context); - for (auto need_force_semi_join_time_exceed : semi_join_time_exceed) + if (cfg.enable_join_v2) { - if (need_force_semi_join_time_exceed) - { - FailPointHelper::enableFailPoint(FailPoints::force_semi_join_time_exceed); - } - else + executeAndAssertColumnsEqual(request, reference); + } + else + { + for (auto need_force_semi_join_time_exceed : semi_join_time_exceed) { - FailPointHelper::disableFailPoint(FailPoints::force_semi_join_time_exceed); + if (need_force_semi_join_time_exceed) + { + FailPointHelper::enableFailPoint(FailPoints::force_semi_join_time_exceed); + } + else + { + FailPointHelper::disableFailPoint(FailPoints::force_semi_join_time_exceed); + } + executeAndAssertColumnsEqual(request, reference); } - executeAndAssertColumnsEqual(request, reference); } } + WRAP_FOR_JOIN_TEST_END } } CATCH diff --git a/dbms/src/Interpreters/JoinV2/HashJoin.cpp b/dbms/src/Interpreters/JoinV2/HashJoin.cpp index 22698e78aeb..9837c473ba8 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoin.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoin.cpp @@ -18,6 +18,8 @@ #include #include #include +#include +#include #include #include #include @@ -108,6 +110,8 @@ StringCollatorKind getStringCollatorKind(const TiDB::TiDBCollators & collators) } // namespace +const DataTypePtr HashJoin::match_helper_type = makeNullable(std::make_shared()); + HashJoin::HashJoin( const Names & key_names_left_, const Names & key_names_right_, @@ -364,11 +368,29 @@ void HashJoin::initProbe(const Block & sample_block, size_t probe_concurrency_) } output_column_indexes.push_back(output_index); } - 
RUNTIME_CHECK_MSG( - output_columns == output_block_after_finalize.columns(), - "output columns {} in all_sample_block_pruned != columns {} in output_block_after_finalize", - output_columns, - output_block_after_finalize.columns()); + if (isLeftOuterSemiFamily(kind)) + { + RUNTIME_CHECK_MSG( + output_columns + 1 == output_block_after_finalize.columns(), + "output columns {} in all_sample_block_pruned + 1 != columns {} in output_block_after_finalize", + output_columns, + output_block_after_finalize.columns()); + RUNTIME_CHECK_MSG( + output_block_after_finalize.has(match_helper_name), + "output_block_after_finalize does not have {} for join kind {}", + match_helper_name, + magic_enum::enum_name(kind)); + + RUNTIME_CHECK(output_block_after_finalize.getByName(match_helper_name).type->equals(*match_helper_type)); + } + else + { + RUNTIME_CHECK_MSG( + output_columns == output_block_after_finalize.columns(), + "output columns {} in all_sample_block_pruned != columns {} in output_block_after_finalize", + output_columns, + output_block_after_finalize.columns()); + } if (has_other_condition) { diff --git a/dbms/src/Interpreters/JoinV2/HashJoin.h b/dbms/src/Interpreters/JoinV2/HashJoin.h index ee3eb0d239a..75676c0e2ac 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoin.h +++ b/dbms/src/Interpreters/JoinV2/HashJoin.h @@ -83,6 +83,8 @@ class HashJoin private: friend JoinProbeBlockHelper; + static const DataTypePtr match_helper_type; + const ASTTableJoin::Kind kind; const String join_req_id; diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index 1f793751d53..0e586ec95e2 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -126,7 +126,9 @@ void JoinProbeContext::prepareForHashProbe( template struct ProbeAdder { + static constexpr bool need_matched = true; static constexpr bool need_not_matched = false; + static constexpr bool break_on_first_match = false; 
static bool ALWAYS_INLINE addMatched( JoinProbeBlockHelper & helper, @@ -160,7 +162,9 @@ struct ProbeAdder template struct ProbeAdder { + static constexpr bool need_matched = true; static constexpr bool need_not_matched = !has_other_condition; + static constexpr bool break_on_first_match = false; static bool ALWAYS_INLINE addMatched( JoinProbeBlockHelper & helper, @@ -216,6 +220,138 @@ struct ProbeAdder } }; +template <> +struct ProbeAdder +{ + static constexpr bool need_matched = true; + static constexpr bool need_not_matched = false; + static constexpr bool break_on_first_match = true; + + static bool ALWAYS_INLINE addMatched( + JoinProbeBlockHelper & helper, + JoinProbeContext &, + JoinProbeWorkerData & wd, + MutableColumns &, + size_t idx, + size_t & current_offset, + RowPtr, + size_t) + { + ++current_offset; + wd.selective_offsets.push_back(idx); + return current_offset >= helper.settings.max_block_size; + } + + static bool ALWAYS_INLINE + addNotMatched(JoinProbeBlockHelper &, JoinProbeContext &, JoinProbeWorkerData &, MutableColumns &, size_t, size_t &) + { + return false; + } + + static void flush(JoinProbeBlockHelper &, JoinProbeWorkerData &, MutableColumns &) {} +}; + +template <> +struct ProbeAdder +{ + static constexpr bool need_matched = false; + static constexpr bool need_not_matched = true; + static constexpr bool break_on_first_match = true; + + static bool ALWAYS_INLINE addMatched( + JoinProbeBlockHelper &, + JoinProbeContext &, + JoinProbeWorkerData &, + MutableColumns &, + size_t, + size_t &, + RowPtr, + size_t) + { + return false; + } + + static bool ALWAYS_INLINE addNotMatched( + JoinProbeBlockHelper & helper, + JoinProbeContext &, + JoinProbeWorkerData & wd, + MutableColumns &, + size_t idx, + size_t & current_offset) + { + ++current_offset; + wd.selective_offsets.push_back(idx); + return current_offset >= helper.settings.max_block_size; + } + + static void flush(JoinProbeBlockHelper &, JoinProbeWorkerData &, MutableColumns &) {} +}; + 
+template <> +struct ProbeAdder +{ + static constexpr bool need_matched = true; + static constexpr bool need_not_matched = false; + static constexpr bool break_on_first_match = true; + + static bool ALWAYS_INLINE addMatched( + JoinProbeBlockHelper &, + JoinProbeContext &, + JoinProbeWorkerData & wd, + MutableColumns &, + size_t idx, + size_t &, + RowPtr, + size_t) + { + wd.match_helper_res[idx] = 1; + return false; + } + + static bool ALWAYS_INLINE + addNotMatched(JoinProbeBlockHelper &, JoinProbeContext &, JoinProbeWorkerData &, MutableColumns &, size_t, size_t &) + { + return false; + } + + static void flush(JoinProbeBlockHelper &, JoinProbeWorkerData &, MutableColumns &) {} +}; + +template <> +struct ProbeAdder +{ + static constexpr bool need_matched = false; + static constexpr bool need_not_matched = true; + static constexpr bool break_on_first_match = true; + + static bool ALWAYS_INLINE addMatched( + JoinProbeBlockHelper &, + JoinProbeContext &, + JoinProbeWorkerData &, + MutableColumns &, + size_t, + size_t &, + RowPtr, + size_t) + { + return false; + } + + static bool ALWAYS_INLINE addNotMatched( + JoinProbeBlockHelper &, + JoinProbeContext &, + JoinProbeWorkerData & wd, + MutableColumns &, + size_t idx, + size_t &) + { + wd.match_helper_res[idx] = 1; + return false; + } + + static void flush(JoinProbeBlockHelper &, JoinProbeWorkerData &, MutableColumns &) {} +}; + JoinProbeBlockHelper::JoinProbeBlockHelper(const HashJoin * join, bool late_materialization) : join(join) , settings(join->settings) @@ -253,22 +389,26 @@ JoinProbeBlockHelper::JoinProbeBlockHelper(const HashJoin * join, bool late_mate CALL2(KeyGetter, JoinType, false, false) \ } -#define CALL(KeyGetter) \ - { \ - auto kind = join->kind; \ - /*bool has_other_condition = join->has_other_condition;*/ \ - if (kind == Inner) \ - CALL1(KeyGetter, Inner) \ - else if (kind == LeftOuter) \ - CALL1(KeyGetter, LeftOuter) \ - /*else if (kind == Semi && !has_other_condition) \ +#define CALL(KeyGetter) \ + { \ 
+ auto kind = join->kind; \ + bool has_other_condition = join->has_other_condition; \ + if (kind == Inner) \ + CALL1(KeyGetter, Inner) \ + else if (kind == LeftOuter) \ + CALL1(KeyGetter, LeftOuter) \ + else if (kind == Semi && !has_other_condition) \ CALL2(KeyGetter, Semi, false, false) \ else if (kind == Anti && !has_other_condition) \ - CALL2(KeyGetter, Anti, false, false)*/ \ - else \ - throw Exception( \ - fmt::format("Logical error: unknown combination of JOIN {}", magic_enum::enum_name(join->kind)), \ - ErrorCodes::LOGICAL_ERROR); \ + CALL2(KeyGetter, Anti, false, false) \ + else if (kind == LeftOuterSemi && !has_other_condition) \ + CALL2(KeyGetter, LeftOuterSemi, false, false) \ + else if (kind == LeftOuterAnti && !has_other_condition) \ + CALL2(KeyGetter, LeftOuterAnti, false, false) \ + else \ + throw Exception( \ + fmt::format("Logical error: unknown combination of JOIN {}", magic_enum::enum_name(join->kind)), \ + ErrorCodes::LOGICAL_ERROR); \ } switch (join->method) @@ -323,6 +463,11 @@ Block JoinProbeBlockHelper::probeImpl(JoinProbeContext & context, JoinProbeWorke wd.not_matched_selective_offsets.clear(); wd.not_matched_selective_offsets.reserve(settings.max_block_size); } + if constexpr ((kind == LeftOuterSemi || kind == LeftOuterAnti) && !has_other_condition) + { + wd.match_helper_res.clear(); + wd.match_helper_res.resize_fill_zero(context.rows); + } if constexpr (late_materialization) { wd.row_ptrs_for_lm.clear(); @@ -383,8 +528,35 @@ Block JoinProbeBlockHelper::probeImpl(JoinProbeContext & context, JoinProbeWorke added_columns); wd.probe_hash_table_time += watch.elapsedFromLastTime(); - if (wd.selective_offsets.empty()) - return join->output_block_after_finalize; + if constexpr (kind == Inner || kind == LeftOuter || kind == Semi || kind == Anti) + { + if (wd.selective_offsets.empty()) + return join->output_block_after_finalize; + } + + if constexpr (kind == LeftOuterSemi || kind == LeftOuterAnti) + { + Block res_block = 
join->output_block_after_finalize.cloneEmpty(); + size_t columns = res_block.columns(); + size_t match_helper_column_index = res_block.getPositionByName(join->match_helper_name); + for (size_t i = 0; i < columns; ++i) + { + if (i == match_helper_column_index) + continue; + res_block.getByPosition(i) = context.block.getByName(res_block.getByPosition(i).name); + } + + MutableColumnPtr match_helper_column_ptr + = res_block.getByPosition(match_helper_column_index).column->cloneEmpty(); + auto * match_helper_column = typeid_cast(match_helper_column_ptr.get()); + match_helper_column->getNullMapColumn().getData().resize_fill_zero(context.rows); + auto * match_helper_res = &typeid_cast &>(match_helper_column->getNestedColumn()).getData(); + match_helper_res->swap(wd.match_helper_res); + + res_block.getByPosition(match_helper_column_index).column = std::move(match_helper_column_ptr); + + return res_block; + } if constexpr (late_materialization) { @@ -524,18 +696,32 @@ void JoinProbeBlockHelper::probeFillColumns( if constexpr (Adder::need_not_matched) is_matched = true; - bool is_end = Adder::addMatched( - *this, - context, - wd, - added_columns, - idx, - current_offset, - ptr, - key_offset + key_getter.getRequiredKeyOffset(key2)); + if constexpr (Adder::need_matched) + { + bool is_end = Adder::addMatched( + *this, + context, + wd, + added_columns, + idx, + current_offset, + ptr, + key_offset + key_getter.getRequiredKeyOffset(key2)); - if unlikely (is_end) + if unlikely (is_end) + { + if constexpr (Adder::break_on_first_match) + ptr = nullptr; + + break; + } + } + + if constexpr (Adder::break_on_first_match) + { + ptr = nullptr; break; + } } ptr = getNextRowPtr(ptr); @@ -613,11 +799,7 @@ void JoinProbeBlockHelper::probeFillColumnsPrefetch( { RowPtr ptr = state->ptr; RowPtr next_ptr = getNextRowPtr(ptr); - if (next_ptr) - { - state->ptr = next_ptr; - PREFETCH_READ(next_ptr); - } + state->ptr = next_ptr; const auto & key2 = key_getter.deserializeJoinKey(ptr + key_offset); bool 
key_is_equal = joinKeyIsEqual(key_getter, state->key, key2, state->hash, ptr); @@ -627,28 +809,40 @@ void JoinProbeBlockHelper::probeFillColumnsPrefetch( if constexpr (Adder::need_not_matched) state->is_matched = true; - bool is_end = Adder::addMatched( - *this, - context, - wd, - added_columns, - state->index, - current_offset, - ptr, - key_offset + key_getter.getRequiredKeyOffset(key2)); - if unlikely (is_end) + if constexpr (Adder::need_matched) { - if (!next_ptr) + bool is_end = Adder::addMatched( + *this, + context, + wd, + added_columns, + state->index, + current_offset, + ptr, + key_offset + key_getter.getRequiredKeyOffset(key2)); + if unlikely (is_end) { - state->stage = ProbePrefetchStage::None; - --active_states; + if constexpr (Adder::break_on_first_match) + next_ptr = nullptr; + + if (!next_ptr) + { + state->stage = ProbePrefetchStage::None; + --active_states; + } + break; } - break; + } + + if constexpr (Adder::break_on_first_match) + { + next_ptr = nullptr; } } if (next_ptr) { + PREFETCH_READ(next_ptr); ++k; continue; } diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h index 3ca77cf06d3..760a9eef6fa 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h @@ -78,8 +78,10 @@ struct JoinProbeContext struct alignas(CPU_CACHE_LINE_SIZE) JoinProbeWorkerData { IColumn::Offsets selective_offsets; - // For left outer join with no other condition + /// For left outer join with no other condition IColumn::Offsets not_matched_selective_offsets; + /// For left outer (anti) semi join with no other condition + PaddedPODArray match_helper_res; RowPtrs insert_batch; From e959d0222e76b39ce84e5fc383a57e1db8f84db5 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Thu, 10 Apr 2025 19:53:59 +0800 Subject: [PATCH 22/84] u Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp | 1 - 1 file changed, 1 deletion(-) diff --git 
a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index 0e586ec95e2..45ce2762aa3 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -712,7 +712,6 @@ void JoinProbeBlockHelper::probeFillColumns( { if constexpr (Adder::break_on_first_match) ptr = nullptr; - break; } } From 7364415e1ef85d3a5ee3f45a43a8ee7804c11d4b Mon Sep 17 00:00:00 2001 From: gengliqi Date: Fri, 11 Apr 2025 14:56:43 +0800 Subject: [PATCH 23/84] remove enable_pipeline for ExecutorTest::executeExecutor Signed-off-by: gengliqi --- dbms/src/TestUtils/ExecutorTestUtils.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/dbms/src/TestUtils/ExecutorTestUtils.cpp b/dbms/src/TestUtils/ExecutorTestUtils.cpp index 971fda5c602..4ce7cb2553d 100644 --- a/dbms/src/TestUtils/ExecutorTestUtils.cpp +++ b/dbms/src/TestUtils/ExecutorTestUtils.cpp @@ -185,8 +185,7 @@ void ExecutorTest::executeExecutor( const std::shared_ptr & request, std::function<::testing::AssertionResult(const ColumnsWithTypeAndName &)> assert_func) { - WRAP_FOR_TEST_BEGIN - std::vector concurrencies{1, 2, 10}; + std::vector concurrencies{1, 10}; for (auto concurrency : concurrencies) { std::vector block_sizes{1, 2, 10, DEFAULT_BLOCK_SIZE}; @@ -194,10 +193,13 @@ void ExecutorTest::executeExecutor( { context.context->setSetting("max_block_size", Field(static_cast(block_size))); auto res = executeStreams(request, concurrency); - ASSERT_TRUE(assert_func(res)) << testInfoMsg(request, enable_pipeline, concurrency, block_size); + ASSERT_TRUE(assert_func(res)) << testInfoMsg( + request, + context.context->getSettingsRef().enable_resource_control, + concurrency, + block_size); } } - WRAP_FOR_TEST_END } void ExecutorTest::checkBlockSorted( From 32fb643c0ee5623cd528b6259f4f707b7ddf4262 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Fri, 11 Apr 2025 16:32:42 +0800 Subject: [PATCH 24/84] refine some code Signed-off-by: 
gengliqi --- dbms/src/Interpreters/JoinV2/HashJoin.cpp | 2 +- dbms/src/Interpreters/JoinV2/HashJoin.h | 4 +- .../src/Interpreters/JoinV2/HashJoinProbe.cpp | 251 ++++++++++++++---- dbms/src/Interpreters/JoinV2/HashJoinProbe.h | 148 +---------- 4 files changed, 210 insertions(+), 195 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoin.cpp b/dbms/src/Interpreters/JoinV2/HashJoin.cpp index 9837c473ba8..28092ae4b63 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoin.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoin.cpp @@ -492,7 +492,7 @@ void HashJoin::workAfterBuildRowFinish() } fiu_do_on(FailPoints::force_join_v2_probe_enable_lm, { late_materialization = true; }); fiu_do_on(FailPoints::force_join_v2_probe_disable_lm, { late_materialization = false; }); - join_probe_helper = std::make_unique(this, late_materialization); + join_probe_helper = std::make_unique(this, late_materialization); LOG_INFO( log, diff --git a/dbms/src/Interpreters/JoinV2/HashJoin.h b/dbms/src/Interpreters/JoinV2/HashJoin.h index 75676c0e2ac..077341a4dbd 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoin.h +++ b/dbms/src/Interpreters/JoinV2/HashJoin.h @@ -81,7 +81,7 @@ class HashJoin void workAfterBuildRowFinish(); private: - friend JoinProbeBlockHelper; + friend JoinProbeHelper; static const DataTypePtr match_helper_type; @@ -150,7 +150,7 @@ class HashJoin size_t probe_concurrency = 0; std::vector probe_workers_data; std::atomic active_probe_worker = 0; - std::unique_ptr join_probe_helper; + std::unique_ptr join_probe_helper; const JoinProfileInfoPtr profile_info = std::make_shared(); diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index 45ce2762aa3..e19c6768f12 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -123,15 +123,45 @@ void JoinProbeContext::prepareForHashProbe( is_prepared = true; } +/// The implemtation of prefetching in join probe process is 
inspired by a paper named +/// `Asynchronous Memory Access Chaining` in vldb-15. +/// Ref: https://www.vldb.org/pvldb/vol9/p252-kocberber.pdf +enum class ProbePrefetchStage : UInt8 +{ + None, + FindHeader, + FindNext, +}; + +template +struct ProbePrefetchState +{ + using KeyGetterType = typename KeyGetter::Type; + using KeyType = typename KeyGetterType::KeyType; + using HashValueType = typename KeyGetter::HashValueType; + + ProbePrefetchStage stage = ProbePrefetchStage::None; + bool is_matched = false; + UInt16 hash_tag = 0; + UInt32 index = 0; + HashValueType hash = 0; + KeyType key{}; + union + { + RowPtr ptr = nullptr; + std::atomic * pointer_ptr; + }; +}; + template -struct ProbeAdder +struct JoinProbeAdder { static constexpr bool need_matched = true; static constexpr bool need_not_matched = false; static constexpr bool break_on_first_match = false; static bool ALWAYS_INLINE addMatched( - JoinProbeBlockHelper & helper, + JoinProbeHelper & helper, JoinProbeContext &, JoinProbeWorkerData & wd, MutableColumns & added_columns, @@ -147,27 +177,27 @@ struct ProbeAdder } static bool ALWAYS_INLINE - addNotMatched(JoinProbeBlockHelper &, JoinProbeContext &, JoinProbeWorkerData &, MutableColumns &, size_t, size_t &) + addNotMatched(JoinProbeHelper &, JoinProbeContext &, JoinProbeWorkerData &, MutableColumns &, size_t, size_t &) { return false; } - static void flush(JoinProbeBlockHelper & helper, JoinProbeWorkerData & wd, MutableColumns & added_columns) + static void flush(JoinProbeHelper & helper, JoinProbeWorkerData & wd, MutableColumns & added_columns) { - helper.flushBatchIfNecessary(wd, added_columns); + helper.flushInsertBatch(wd, added_columns); helper.fillNullMapWithZero(added_columns); } }; template -struct ProbeAdder +struct JoinProbeAdder { static constexpr bool need_matched = true; static constexpr bool need_not_matched = !has_other_condition; static constexpr bool break_on_first_match = false; static bool ALWAYS_INLINE addMatched( - JoinProbeBlockHelper & 
helper, + JoinProbeHelper & helper, JoinProbeContext &, JoinProbeWorkerData & wd, MutableColumns & added_columns, @@ -183,7 +213,7 @@ struct ProbeAdder } static bool ALWAYS_INLINE addNotMatched( - JoinProbeBlockHelper & helper, + JoinProbeHelper & helper, JoinProbeContext &, JoinProbeWorkerData & wd, MutableColumns &, @@ -199,9 +229,9 @@ struct ProbeAdder return false; } - static void flush(JoinProbeBlockHelper & helper, JoinProbeWorkerData & wd, MutableColumns & added_columns) + static void flush(JoinProbeHelper & helper, JoinProbeWorkerData & wd, MutableColumns & added_columns) { - helper.flushBatchIfNecessary(wd, added_columns); + helper.flushInsertBatch(wd, added_columns); helper.fillNullMapWithZero(added_columns); if constexpr (!has_other_condition) @@ -221,14 +251,14 @@ struct ProbeAdder }; template <> -struct ProbeAdder +struct JoinProbeAdder { static constexpr bool need_matched = true; static constexpr bool need_not_matched = false; static constexpr bool break_on_first_match = true; static bool ALWAYS_INLINE addMatched( - JoinProbeBlockHelper & helper, + JoinProbeHelper & helper, JoinProbeContext &, JoinProbeWorkerData & wd, MutableColumns &, @@ -243,23 +273,23 @@ struct ProbeAdder } static bool ALWAYS_INLINE - addNotMatched(JoinProbeBlockHelper &, JoinProbeContext &, JoinProbeWorkerData &, MutableColumns &, size_t, size_t &) + addNotMatched(JoinProbeHelper &, JoinProbeContext &, JoinProbeWorkerData &, MutableColumns &, size_t, size_t &) { return false; } - static void flush(JoinProbeBlockHelper &, JoinProbeWorkerData &, MutableColumns &) {} + static void flush(JoinProbeHelper &, JoinProbeWorkerData &, MutableColumns &) {} }; template <> -struct ProbeAdder +struct JoinProbeAdder { static constexpr bool need_matched = false; static constexpr bool need_not_matched = true; static constexpr bool break_on_first_match = true; static bool ALWAYS_INLINE addMatched( - JoinProbeBlockHelper &, + JoinProbeHelper &, JoinProbeContext &, JoinProbeWorkerData &, 
MutableColumns &, @@ -272,7 +302,7 @@ struct ProbeAdder } static bool ALWAYS_INLINE addNotMatched( - JoinProbeBlockHelper & helper, + JoinProbeHelper & helper, JoinProbeContext &, JoinProbeWorkerData & wd, MutableColumns &, @@ -284,18 +314,18 @@ struct ProbeAdder return current_offset >= helper.settings.max_block_size; } - static void flush(JoinProbeBlockHelper &, JoinProbeWorkerData &, MutableColumns &) {} + static void flush(JoinProbeHelper &, JoinProbeWorkerData &, MutableColumns &) {} }; template <> -struct ProbeAdder +struct JoinProbeAdder { static constexpr bool need_matched = true; static constexpr bool need_not_matched = false; static constexpr bool break_on_first_match = true; static bool ALWAYS_INLINE addMatched( - JoinProbeBlockHelper &, + JoinProbeHelper &, JoinProbeContext &, JoinProbeWorkerData & wd, MutableColumns &, @@ -309,23 +339,23 @@ struct ProbeAdder } static bool ALWAYS_INLINE - addNotMatched(JoinProbeBlockHelper &, JoinProbeContext &, JoinProbeWorkerData &, MutableColumns &, size_t, size_t &) + addNotMatched(JoinProbeHelper &, JoinProbeContext &, JoinProbeWorkerData &, MutableColumns &, size_t, size_t &) { return false; } - static void flush(JoinProbeBlockHelper &, JoinProbeWorkerData &, MutableColumns &) {} + static void flush(JoinProbeHelper &, JoinProbeWorkerData &, MutableColumns &) {} }; template <> -struct ProbeAdder +struct JoinProbeAdder { static constexpr bool need_matched = false; static constexpr bool need_not_matched = true; static constexpr bool break_on_first_match = true; static bool ALWAYS_INLINE addMatched( - JoinProbeBlockHelper &, + JoinProbeHelper &, JoinProbeContext &, JoinProbeWorkerData &, MutableColumns &, @@ -338,7 +368,7 @@ struct ProbeAdder } static bool ALWAYS_INLINE addNotMatched( - JoinProbeBlockHelper &, + JoinProbeHelper &, JoinProbeContext &, JoinProbeWorkerData & wd, MutableColumns &, @@ -349,10 +379,10 @@ struct ProbeAdder return false; } - static void flush(JoinProbeBlockHelper &, JoinProbeWorkerData &, 
MutableColumns &) {} + static void flush(JoinProbeHelper &, JoinProbeWorkerData &, MutableColumns &) {} }; -JoinProbeBlockHelper::JoinProbeBlockHelper(const HashJoin * join, bool late_materialization) +JoinProbeHelper::JoinProbeHelper(const HashJoin * join, bool late_materialization) : join(join) , settings(join->settings) , pointer_table(join->pointer_table) @@ -361,10 +391,10 @@ JoinProbeBlockHelper::JoinProbeBlockHelper(const HashJoin * join, bool late_mate #define CALL3(KeyGetter, JoinType, has_other_condition, late_materialization, tagged_pointer) \ { \ func_ptr_has_null \ - = &JoinProbeBlockHelper:: \ + = &JoinProbeHelper:: \ probeImpl; \ func_ptr_no_null \ - = &JoinProbeBlockHelper:: \ + = &JoinProbeHelper:: \ probeImpl; \ } @@ -432,7 +462,7 @@ JoinProbeBlockHelper::JoinProbeBlockHelper(const HashJoin * join, bool late_mate #undef CALL3 } -Block JoinProbeBlockHelper::probe(JoinProbeContext & context, JoinProbeWorkerData & wd) +Block JoinProbeHelper::probe(JoinProbeContext & context, JoinProbeWorkerData & wd) { if (context.null_map) return (this->*func_ptr_has_null)(context, wd); @@ -441,7 +471,7 @@ Block JoinProbeBlockHelper::probe(JoinProbeContext & context, JoinProbeWorkerDat } JOIN_PROBE_TEMPLATE -Block JoinProbeBlockHelper::probeImpl(JoinProbeContext & context, JoinProbeWorkerData & wd) +Block JoinProbeHelper::probeImpl(JoinProbeContext & context, JoinProbeWorkerData & wd) { static_assert(has_other_condition || !late_materialization); @@ -536,26 +566,7 @@ Block JoinProbeBlockHelper::probeImpl(JoinProbeContext & context, JoinProbeWorke if constexpr (kind == LeftOuterSemi || kind == LeftOuterAnti) { - Block res_block = join->output_block_after_finalize.cloneEmpty(); - size_t columns = res_block.columns(); - size_t match_helper_column_index = res_block.getPositionByName(join->match_helper_name); - for (size_t i = 0; i < columns; ++i) - { - if (i == match_helper_column_index) - continue; - res_block.getByPosition(i) = 
context.block.getByName(res_block.getByPosition(i).name); - } - - MutableColumnPtr match_helper_column_ptr - = res_block.getByPosition(match_helper_column_index).column->cloneEmpty(); - auto * match_helper_column = typeid_cast(match_helper_column_ptr.get()); - match_helper_column->getNullMapColumn().getData().resize_fill_zero(context.rows); - auto * match_helper_res = &typeid_cast &>(match_helper_column->getNestedColumn()).getData(); - match_helper_res->swap(wd.match_helper_res); - - res_block.getByPosition(match_helper_column_index).column = std::move(match_helper_column_ptr); - - return res_block; + return genResultBlockForLeftOuterSemi(context, wd); } if constexpr (late_materialization) @@ -616,7 +627,7 @@ Block JoinProbeBlockHelper::probeImpl(JoinProbeContext & context, JoinProbeWorke } JOIN_PROBE_TEMPLATE -void JoinProbeBlockHelper::probeFillColumns( +void JoinProbeHelper::probeFillColumns( JoinProbeContext & context, JoinProbeWorkerData & wd, MutableColumns & added_columns) @@ -624,7 +635,7 @@ void JoinProbeBlockHelper::probeFillColumns( using KeyGetterType = typename KeyGetter::Type; using Hash = typename KeyGetter::Hash; using HashValueType = typename KeyGetter::HashValueType; - using Adder = ProbeAdder; + using Adder = JoinProbeAdder; auto & key_getter = *static_cast(context.key_getter.get()); size_t current_offset = wd.result_block.rows(); @@ -753,7 +764,7 @@ void JoinProbeBlockHelper::probeFillColumns( #define PREFETCH_READ(ptr) __builtin_prefetch((ptr), 0 /* rw==read */, 3 /* locality */) JOIN_PROBE_TEMPLATE -void JoinProbeBlockHelper::probeFillColumnsPrefetch( +void JoinProbeHelper::probeFillColumnsPrefetch( JoinProbeContext & context, JoinProbeWorkerData & wd, MutableColumns & added_columns) @@ -761,7 +772,7 @@ void JoinProbeBlockHelper::probeFillColumnsPrefetch( using KeyGetterType = typename KeyGetter::Type; using Hash = typename KeyGetter::Hash; using HashValueType = typename KeyGetter::HashValueType; - using Adder = ProbeAdder; + using Adder = 
JoinProbeAdder; auto & key_getter = *static_cast(context.key_getter.get()); initPrefetchStates(context); @@ -945,7 +956,7 @@ void JoinProbeBlockHelper::probeFillColumnsPrefetch( #undef NOT_MATCHED } -Block JoinProbeBlockHelper::handleOtherConditions( +Block JoinProbeHelper::handleOtherConditions( JoinProbeContext & context, JoinProbeWorkerData & wd, ASTTableJoin::Kind kind, @@ -1204,7 +1215,106 @@ Block JoinProbeBlockHelper::handleOtherConditions( return output_block_after_finalize; } -Block JoinProbeBlockHelper::fillNotMatchedRowsForLeftOuter(JoinProbeContext & context, JoinProbeWorkerData & wd) +template +void JoinProbeHelper::initPrefetchStates(JoinProbeContext & context) +{ + if (!context.prefetch_states) + { + context.prefetch_states = decltype(context.prefetch_states)( + static_cast(new ProbePrefetchState[settings.probe_prefetch_step]), + [](void * ptr) { delete[] static_cast *>(ptr); }); + } +} + +template +void JoinProbeHelper::flushInsertBatch(JoinProbeWorkerData & wd, MutableColumns & added_columns) const +{ + if constexpr (late_materialization) + { + size_t idx = 0; + for (auto [_, is_nullable] : row_layout.raw_key_column_indexes) + { + IColumn * column = added_columns[idx].get(); + if (is_nullable) + column = &static_cast(*added_columns[idx]).getNestedColumn(); + column->deserializeAndInsertFromPos(wd.insert_batch, true); + ++idx; + } + for (size_t i = 0; i < row_layout.other_column_count_for_other_condition; ++i) + added_columns[idx++]->deserializeAndInsertFromPos(wd.insert_batch, true); + + wd.row_ptrs_for_lm.insert(wd.insert_batch.begin(), wd.insert_batch.end()); + } + else + { + for (auto [column_index, is_nullable] : row_layout.raw_key_column_indexes) + { + IColumn * column = added_columns[column_index].get(); + if (is_nullable) + column = &static_cast(*added_columns[column_index]).getNestedColumn(); + column->deserializeAndInsertFromPos(wd.insert_batch, true); + } + for (auto [column_index, _] : row_layout.other_column_indexes) + 
added_columns[column_index]->deserializeAndInsertFromPos(wd.insert_batch, true); + } + + if constexpr (last_flush) + { + if constexpr (late_materialization) + { + size_t idx = 0; + for (auto [_, is_nullable] : row_layout.raw_key_column_indexes) + { + IColumn * column = added_columns[idx].get(); + if (is_nullable) + column = &static_cast(*added_columns[idx]).getNestedColumn(); + column->flushNTAlignBuffer(); + ++idx; + } + for (size_t i = 0; i < row_layout.other_column_count_for_other_condition; ++i) + added_columns[idx++]->flushNTAlignBuffer(); + } + else + { + for (auto [column_index, is_nullable] : row_layout.raw_key_column_indexes) + { + IColumn * column = added_columns[column_index].get(); + if (is_nullable) + column = &static_cast(*added_columns[column_index]).getNestedColumn(); + column->flushNTAlignBuffer(); + } + for (auto [column_index, _] : row_layout.other_column_indexes) + added_columns[column_index]->flushNTAlignBuffer(); + } + } + + wd.insert_batch.clear(); +} + +template +void JoinProbeHelper::fillNullMapWithZero(MutableColumns & added_columns) const +{ + size_t idx = 0; + for (auto [column_index, is_nullable] : row_layout.raw_key_column_indexes) + { + if (is_nullable) + { + size_t index; + if constexpr (late_materialization) + index = idx; + else + index = column_index; + auto & nullable_column = static_cast(*added_columns[index]); + size_t data_size = nullable_column.getNestedColumn().size(); + size_t nullmap_size = nullable_column.getNullMapColumn().size(); + RUNTIME_CHECK(nullmap_size <= data_size); + nullable_column.getNullMapColumn().getData().resize_fill_zero(data_size); + } + ++idx; + } +} + +Block JoinProbeHelper::fillNotMatchedRowsForLeftOuter(JoinProbeContext & context, JoinProbeWorkerData & wd) { RUNTIME_CHECK(join->kind == LeftOuter); RUNTIME_CHECK(join->has_other_condition); @@ -1301,4 +1411,31 @@ Block JoinProbeBlockHelper::fillNotMatchedRowsForLeftOuter(JoinProbeContext & co return output_block_after_finalize; } +Block 
JoinProbeHelper::genResultBlockForLeftOuterSemi(JoinProbeContext & context, JoinProbeWorkerData & wd) +{ + RUNTIME_CHECK(join->kind == LeftOuterSemi || join->kind == LeftOuterAnti); + RUNTIME_CHECK(!join->has_other_condition); + RUNTIME_CHECK(context.isProbeFinished()); + + Block res_block = join->output_block_after_finalize.cloneEmpty(); + size_t columns = res_block.columns(); + size_t match_helper_column_index = res_block.getPositionByName(join->match_helper_name); + for (size_t i = 0; i < columns; ++i) + { + if (i == match_helper_column_index) + continue; + res_block.getByPosition(i) = context.block.getByName(res_block.getByPosition(i).name); + } + + MutableColumnPtr match_helper_column_ptr = res_block.getByPosition(match_helper_column_index).column->cloneEmpty(); + auto * match_helper_column = typeid_cast(match_helper_column_ptr.get()); + match_helper_column->getNullMapColumn().getData().resize_fill_zero(context.rows); + auto * match_helper_res = &typeid_cast &>(match_helper_column->getNestedColumn()).getData(); + match_helper_res->swap(wd.match_helper_res); + + res_block.getByPosition(match_helper_column_index).column = std::move(match_helper_column_ptr); + + return res_block; +} + } // namespace DB diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h index 760a9eef6fa..230a65783fe 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h @@ -107,38 +107,8 @@ struct alignas(CPU_CACHE_LINE_SIZE) JoinProbeWorkerData size_t collision = 0; }; -/// The implemtation of prefetching in join probe process is inspired by a paper named -/// `Asynchronous Memory Access Chaining` in vldb-15. 
-/// Ref: https://www.vldb.org/pvldb/vol9/p252-kocberber.pdf -enum class ProbePrefetchStage : UInt8 -{ - None, - FindHeader, - FindNext, -}; - -template -struct ProbePrefetchState -{ - using KeyGetterType = typename KeyGetter::Type; - using KeyType = typename KeyGetterType::KeyType; - using HashValueType = typename KeyGetter::HashValueType; - - ProbePrefetchStage stage = ProbePrefetchStage::None; - bool is_matched = false; - UInt16 hash_tag = 0; - HashValueType hash = 0; - size_t index = 0; - union - { - RowPtr ptr = nullptr; - std::atomic * pointer_ptr; - }; - KeyType key{}; -}; - template -struct ProbeAdder; +struct JoinProbeAdder; #define JOIN_PROBE_TEMPLATE \ template < \ @@ -150,10 +120,10 @@ struct ProbeAdder; bool tagged_pointer> class HashJoin; -class JoinProbeBlockHelper +class JoinProbeHelper { public: - JoinProbeBlockHelper(const HashJoin * join, bool late_materialization); + JoinProbeHelper(const HashJoin * join, bool late_materialization); Block probe(JoinProbeContext & context, JoinProbeWorkerData & wd); @@ -169,15 +139,7 @@ class JoinProbeBlockHelper probeFillColumnsPrefetch(JoinProbeContext & context, JoinProbeWorkerData & wd, MutableColumns & added_columns); template - void ALWAYS_INLINE initPrefetchStates(JoinProbeContext & context) - { - if (!context.prefetch_states) - { - context.prefetch_states = decltype(context.prefetch_states)( - static_cast(new ProbePrefetchState[settings.probe_prefetch_step]), - [](void * ptr) { delete[] static_cast *>(ptr); }); - } - } + void initPrefetchStates(JoinProbeContext & context); template bool ALWAYS_INLINE joinKeyIsEqual( @@ -200,101 +162,15 @@ class JoinProbeBlockHelper void ALWAYS_INLINE insertRowToBatch(JoinProbeWorkerData & wd, MutableColumns & added_columns, RowPtr row_ptr) const { wd.insert_batch.push_back(row_ptr); - flushBatchIfNecessary(wd, added_columns); + if unlikely (wd.insert_batch.size() >= settings.probe_insert_batch_size) + flushInsertBatch(wd, added_columns); } - template - void ALWAYS_INLINE 
flushBatchIfNecessary(JoinProbeWorkerData & wd, MutableColumns & added_columns) const - { - if constexpr (!force) - { - if likely (wd.insert_batch.size() < settings.probe_insert_batch_size) - return; - } - if constexpr (late_materialization) - { - size_t idx = 0; - for (auto [_, is_nullable] : row_layout.raw_key_column_indexes) - { - IColumn * column = added_columns[idx].get(); - if (is_nullable) - column = &static_cast(*added_columns[idx]).getNestedColumn(); - column->deserializeAndInsertFromPos(wd.insert_batch, true); - ++idx; - } - for (size_t i = 0; i < row_layout.other_column_count_for_other_condition; ++i) - added_columns[idx++]->deserializeAndInsertFromPos(wd.insert_batch, true); - - wd.row_ptrs_for_lm.insert(wd.insert_batch.begin(), wd.insert_batch.end()); - } - else - { - for (auto [column_index, is_nullable] : row_layout.raw_key_column_indexes) - { - IColumn * column = added_columns[column_index].get(); - if (is_nullable) - column = &static_cast(*added_columns[column_index]).getNestedColumn(); - column->deserializeAndInsertFromPos(wd.insert_batch, true); - } - for (auto [column_index, _] : row_layout.other_column_indexes) - added_columns[column_index]->deserializeAndInsertFromPos(wd.insert_batch, true); - } - - if constexpr (force) - { - if constexpr (late_materialization) - { - size_t idx = 0; - for (auto [_, is_nullable] : row_layout.raw_key_column_indexes) - { - IColumn * column = added_columns[idx].get(); - if (is_nullable) - column = &static_cast(*added_columns[idx]).getNestedColumn(); - column->flushNTAlignBuffer(); - ++idx; - } - for (size_t i = 0; i < row_layout.other_column_count_for_other_condition; ++i) - added_columns[idx++]->flushNTAlignBuffer(); - } - else - { - for (auto [column_index, is_nullable] : row_layout.raw_key_column_indexes) - { - IColumn * column = added_columns[column_index].get(); - if (is_nullable) - column = &static_cast(*added_columns[column_index]).getNestedColumn(); - column->flushNTAlignBuffer(); - } - for (auto 
[column_index, _] : row_layout.other_column_indexes) - added_columns[column_index]->flushNTAlignBuffer(); - } - } - - wd.insert_batch.clear(); - } + template + void flushInsertBatch(JoinProbeWorkerData & wd, MutableColumns & added_columns) const; template - void ALWAYS_INLINE fillNullMapWithZero(MutableColumns & added_columns) const - { - size_t idx = 0; - for (auto [column_index, is_nullable] : row_layout.raw_key_column_indexes) - { - if (is_nullable) - { - size_t index; - if constexpr (late_materialization) - index = idx; - else - index = column_index; - auto & nullable_column = static_cast(*added_columns[index]); - size_t data_size = nullable_column.getNestedColumn().size(); - size_t nullmap_size = nullable_column.getNullMapColumn().size(); - RUNTIME_CHECK(nullmap_size <= data_size); - nullable_column.getNullMapColumn().getData().resize_fill_zero(data_size); - } - ++idx; - } - } + void fillNullMapWithZero(MutableColumns & added_columns) const; Block handleOtherConditions( JoinProbeContext & context, @@ -304,11 +180,13 @@ class JoinProbeBlockHelper Block fillNotMatchedRowsForLeftOuter(JoinProbeContext & context, JoinProbeWorkerData & wd); + Block genResultBlockForLeftOuterSemi(JoinProbeContext & context, JoinProbeWorkerData & wd); + private: template - friend struct ProbeAdder; + friend struct JoinProbeAdder; - using FuncType = Block (JoinProbeBlockHelper::*)(JoinProbeContext &, JoinProbeWorkerData &); + using FuncType = Block (JoinProbeHelper::*)(JoinProbeContext &, JoinProbeWorkerData &); FuncType func_ptr_has_null = nullptr; FuncType func_ptr_no_null = nullptr; const HashJoin * join; From 1568c1a153619ec34404150b212b86eb3747baea Mon Sep 17 00:00:00 2001 From: gengliqi Date: Fri, 11 Apr 2025 19:08:09 +0800 Subject: [PATCH 25/84] u Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp 
b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index e19c6768f12..520223344bd 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -588,7 +588,7 @@ Block JoinProbeHelper::probeImpl(JoinProbeContext & context, JoinProbeWorkerData if constexpr (has_other_condition) { - // Always using late materialization for left side + // Always using late materialization for left side columns for (size_t i = 0; i < left_columns; ++i) { if (!join->left_required_flag_for_other_condition[i]) @@ -745,10 +745,7 @@ void JoinProbeHelper::probeFillColumns( ++idx; break; } - if constexpr (Adder::need_not_matched) - { - NOT_MATCHED(!is_matched) - } + NOT_MATCHED(!is_matched) } Adder::flush(*this, wd, added_columns); From 106634c0ebf1695bed36e208ac8d070a5200da1b Mon Sep 17 00:00:00 2001 From: gengliqi Date: Mon, 14 Apr 2025 14:01:33 +0800 Subject: [PATCH 26/84] use safeGetByPosition in some places Signed-off-by: gengliqi --- dbms/src/Flash/Mpp/HashBaseWriterHelper.cpp | 12 ++++++++---- dbms/src/Flash/Mpp/HashBaseWriterHelper.h | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/dbms/src/Flash/Mpp/HashBaseWriterHelper.cpp b/dbms/src/Flash/Mpp/HashBaseWriterHelper.cpp index 0035b528ec6..7b3c9f7a853 100644 --- a/dbms/src/Flash/Mpp/HashBaseWriterHelper.cpp +++ b/dbms/src/Flash/Mpp/HashBaseWriterHelper.cpp @@ -94,18 +94,20 @@ void computeHash( if unlikely (rows == 0) return; + RUNTIME_CHECK(collators.size() == partition_col_ids.size()); + RUNTIME_CHECK(partition_key_containers.size() == partition_col_ids.size()); hash.reset(rows); /// compute hash values for (size_t i = 0; i < partition_col_ids.size(); ++i) { - const auto & column = block.getByPosition(partition_col_ids[i]).column; + const auto & column = block.safeGetByPosition(partition_col_ids[i]).column; column->updateWeakHash32(hash, collators[i], partition_key_containers[i]); } } void computeHashSelectiveBlock( const Block & block, - const std::vector & 
partition_id_cols, + const std::vector & partition_col_ids, const TiDB::TiDBCollators & collators, std::vector & partition_key_containers, WeakHash32 & hash) @@ -113,10 +115,12 @@ void computeHashSelectiveBlock( RUNTIME_CHECK(block.info.selective && !block.info.selective->empty()); const auto selective_rows = block.info.selective->size(); + RUNTIME_CHECK(collators.size() == partition_col_ids.size()); + RUNTIME_CHECK(partition_key_containers.size() == partition_col_ids.size()); hash.reset(selective_rows); - for (size_t i = 0; i < partition_id_cols.size(); ++i) + for (size_t i = 0; i < partition_col_ids.size(); ++i) { - const auto & column = block.getByPosition(partition_id_cols[i]).column; + const auto & column = block.safeGetByPosition(partition_col_ids[i]).column; column->updateWeakHash32(hash, collators[i], partition_key_containers[i], *block.info.selective); } } diff --git a/dbms/src/Flash/Mpp/HashBaseWriterHelper.h b/dbms/src/Flash/Mpp/HashBaseWriterHelper.h index 7669500f24a..28738bc065e 100644 --- a/dbms/src/Flash/Mpp/HashBaseWriterHelper.h +++ b/dbms/src/Flash/Mpp/HashBaseWriterHelper.h @@ -43,7 +43,7 @@ void computeHash( void computeHashSelectiveBlock( const Block & block, - const std::vector & partition_id_cols, + const std::vector & partition_col_ids, const TiDB::TiDBCollators & collators, std::vector & partition_key_containers, WeakHash32 & hash); From 36e565bd6821ddcdcaec78b5136c8f0e81028ea2 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Mon, 14 Apr 2025 16:00:00 +0800 Subject: [PATCH 27/84] tiny refine Signed-off-by: gengliqi --- .../src/Interpreters/JoinV2/HashJoinProbe.cpp | 181 +++++++++--------- dbms/src/Interpreters/JoinV2/HashJoinProbe.h | 79 ++++---- 2 files changed, 135 insertions(+), 125 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index 520223344bd..a06da07a23d 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ 
b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -123,6 +123,94 @@ void JoinProbeContext::prepareForHashProbe( is_prepared = true; } +template +void JoinProbeHelperUtil::flushInsertBatch(JoinProbeWorkerData & wd, MutableColumns & added_columns) const +{ + if constexpr (late_materialization) + { + size_t idx = 0; + for (auto [_, is_nullable] : row_layout.raw_key_column_indexes) + { + IColumn * column = added_columns[idx].get(); + if (is_nullable) + column = &static_cast(*added_columns[idx]).getNestedColumn(); + column->deserializeAndInsertFromPos(wd.insert_batch, true); + ++idx; + } + for (size_t i = 0; i < row_layout.other_column_count_for_other_condition; ++i) + added_columns[idx++]->deserializeAndInsertFromPos(wd.insert_batch, true); + + wd.row_ptrs_for_lm.insert(wd.insert_batch.begin(), wd.insert_batch.end()); + } + else + { + for (auto [column_index, is_nullable] : row_layout.raw_key_column_indexes) + { + IColumn * column = added_columns[column_index].get(); + if (is_nullable) + column = &static_cast(*added_columns[column_index]).getNestedColumn(); + column->deserializeAndInsertFromPos(wd.insert_batch, true); + } + for (auto [column_index, _] : row_layout.other_column_indexes) + added_columns[column_index]->deserializeAndInsertFromPos(wd.insert_batch, true); + } + + if constexpr (last_flush) + { + if constexpr (late_materialization) + { + size_t idx = 0; + for (auto [_, is_nullable] : row_layout.raw_key_column_indexes) + { + IColumn * column = added_columns[idx].get(); + if (is_nullable) + column = &static_cast(*added_columns[idx]).getNestedColumn(); + column->flushNTAlignBuffer(); + ++idx; + } + for (size_t i = 0; i < row_layout.other_column_count_for_other_condition; ++i) + added_columns[idx++]->flushNTAlignBuffer(); + } + else + { + for (auto [column_index, is_nullable] : row_layout.raw_key_column_indexes) + { + IColumn * column = added_columns[column_index].get(); + if (is_nullable) + column = &static_cast(*added_columns[column_index]).getNestedColumn(); 
+ column->flushNTAlignBuffer(); + } + for (auto [column_index, _] : row_layout.other_column_indexes) + added_columns[column_index]->flushNTAlignBuffer(); + } + } + + wd.insert_batch.clear(); +} + +template +void JoinProbeHelperUtil::fillNullMapWithZero(MutableColumns & added_columns) const +{ + size_t idx = 0; + for (auto [column_index, is_nullable] : row_layout.raw_key_column_indexes) + { + if (is_nullable) + { + size_t index; + if constexpr (late_materialization) + index = idx; + else + index = column_index; + auto & nullable_column = static_cast(*added_columns[index]); + size_t data_size = nullable_column.getNestedColumn().size(); + size_t nullmap_size = nullable_column.getNullMapColumn().size(); + RUNTIME_CHECK(nullmap_size <= data_size); + nullable_column.getNullMapColumn().getData().resize_fill_zero(data_size); + } + ++idx; + } +} + /// The implemtation of prefetching in join probe process is inspired by a paper named /// `Asynchronous Memory Access Chaining` in vldb-15. /// Ref: https://www.vldb.org/pvldb/vol9/p252-kocberber.pdf @@ -383,10 +471,9 @@ struct JoinProbeAdder }; JoinProbeHelper::JoinProbeHelper(const HashJoin * join, bool late_materialization) - : join(join) - , settings(join->settings) + : JoinProbeHelperUtil(join->settings, join->row_layout) + , join(join) , pointer_table(join->pointer_table) - , row_layout(join->row_layout) { #define CALL3(KeyGetter, JoinType, has_other_condition, late_materialization, tagged_pointer) \ { \ @@ -1223,94 +1310,6 @@ void JoinProbeHelper::initPrefetchStates(JoinProbeContext & context) } } -template -void JoinProbeHelper::flushInsertBatch(JoinProbeWorkerData & wd, MutableColumns & added_columns) const -{ - if constexpr (late_materialization) - { - size_t idx = 0; - for (auto [_, is_nullable] : row_layout.raw_key_column_indexes) - { - IColumn * column = added_columns[idx].get(); - if (is_nullable) - column = &static_cast(*added_columns[idx]).getNestedColumn(); - column->deserializeAndInsertFromPos(wd.insert_batch, 
true); - ++idx; - } - for (size_t i = 0; i < row_layout.other_column_count_for_other_condition; ++i) - added_columns[idx++]->deserializeAndInsertFromPos(wd.insert_batch, true); - - wd.row_ptrs_for_lm.insert(wd.insert_batch.begin(), wd.insert_batch.end()); - } - else - { - for (auto [column_index, is_nullable] : row_layout.raw_key_column_indexes) - { - IColumn * column = added_columns[column_index].get(); - if (is_nullable) - column = &static_cast(*added_columns[column_index]).getNestedColumn(); - column->deserializeAndInsertFromPos(wd.insert_batch, true); - } - for (auto [column_index, _] : row_layout.other_column_indexes) - added_columns[column_index]->deserializeAndInsertFromPos(wd.insert_batch, true); - } - - if constexpr (last_flush) - { - if constexpr (late_materialization) - { - size_t idx = 0; - for (auto [_, is_nullable] : row_layout.raw_key_column_indexes) - { - IColumn * column = added_columns[idx].get(); - if (is_nullable) - column = &static_cast(*added_columns[idx]).getNestedColumn(); - column->flushNTAlignBuffer(); - ++idx; - } - for (size_t i = 0; i < row_layout.other_column_count_for_other_condition; ++i) - added_columns[idx++]->flushNTAlignBuffer(); - } - else - { - for (auto [column_index, is_nullable] : row_layout.raw_key_column_indexes) - { - IColumn * column = added_columns[column_index].get(); - if (is_nullable) - column = &static_cast(*added_columns[column_index]).getNestedColumn(); - column->flushNTAlignBuffer(); - } - for (auto [column_index, _] : row_layout.other_column_indexes) - added_columns[column_index]->flushNTAlignBuffer(); - } - } - - wd.insert_batch.clear(); -} - -template -void JoinProbeHelper::fillNullMapWithZero(MutableColumns & added_columns) const -{ - size_t idx = 0; - for (auto [column_index, is_nullable] : row_layout.raw_key_column_indexes) - { - if (is_nullable) - { - size_t index; - if constexpr (late_materialization) - index = idx; - else - index = column_index; - auto & nullable_column = 
static_cast(*added_columns[index]); - size_t data_size = nullable_column.getNestedColumn().size(); - size_t nullmap_size = nullable_column.getNullMapColumn().size(); - RUNTIME_CHECK(nullmap_size <= data_size); - nullable_column.getNullMapColumn().getData().resize_fill_zero(data_size); - } - ++idx; - } -} - Block JoinProbeHelper::fillNotMatchedRowsForLeftOuter(JoinProbeContext & context, JoinProbeWorkerData & wd) { RUNTIME_CHECK(join->kind == LeftOuter); diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h index 230a65783fe..c6f99149c6a 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h @@ -107,6 +107,50 @@ struct alignas(CPU_CACHE_LINE_SIZE) JoinProbeWorkerData size_t collision = 0; }; +class JoinProbeHelperUtil +{ +public: + explicit JoinProbeHelperUtil(const HashJoinSettings & settings, const HashJoinRowLayout & row_layout) + : settings(settings) + , row_layout(row_layout) + {} + + template + static bool ALWAYS_INLINE joinKeyIsEqual( + KeyGetterType & key_getter, + const KeyType & key1, + const KeyType & key2, + HashValueType hash1, + RowPtr row_ptr) + { + if constexpr (KeyGetterType::joinKeyCompareHashFirst()) + { + auto hash2 = unalignedLoad(row_ptr + sizeof(RowPtr)); + if (hash1 != hash2) + return false; + } + return key_getter.joinKeyIsEqual(key1, key2); + } + + template + void ALWAYS_INLINE insertRowToBatch(JoinProbeWorkerData & wd, MutableColumns & added_columns, RowPtr row_ptr) const + { + wd.insert_batch.push_back(row_ptr); + if unlikely (wd.insert_batch.size() >= settings.probe_insert_batch_size) + flushInsertBatch(wd, added_columns); + } + + template + void flushInsertBatch(JoinProbeWorkerData & wd, MutableColumns & added_columns) const; + + template + void fillNullMapWithZero(MutableColumns & added_columns) const; + +protected: + const HashJoinSettings & settings; + const HashJoinRowLayout & row_layout; +}; + template struct JoinProbeAdder; @@ 
-120,7 +164,7 @@ struct JoinProbeAdder; bool tagged_pointer> class HashJoin; -class JoinProbeHelper +class JoinProbeHelper final : public JoinProbeHelperUtil { public: JoinProbeHelper(const HashJoin * join, bool late_materialization); @@ -141,37 +185,6 @@ class JoinProbeHelper template void initPrefetchStates(JoinProbeContext & context); - template - bool ALWAYS_INLINE joinKeyIsEqual( - KeyGetterType & key_getter, - const KeyType & key1, - const KeyType & key2, - HashValueType hash1, - RowPtr row_ptr) const - { - if constexpr (KeyGetterType::joinKeyCompareHashFirst()) - { - auto hash2 = unalignedLoad(row_ptr + sizeof(RowPtr)); - if (hash1 != hash2) - return false; - } - return key_getter.joinKeyIsEqual(key1, key2); - } - - template - void ALWAYS_INLINE insertRowToBatch(JoinProbeWorkerData & wd, MutableColumns & added_columns, RowPtr row_ptr) const - { - wd.insert_batch.push_back(row_ptr); - if unlikely (wd.insert_batch.size() >= settings.probe_insert_batch_size) - flushInsertBatch(wd, added_columns); - } - - template - void flushInsertBatch(JoinProbeWorkerData & wd, MutableColumns & added_columns) const; - - template - void fillNullMapWithZero(MutableColumns & added_columns) const; - Block handleOtherConditions( JoinProbeContext & context, JoinProbeWorkerData & wd, @@ -190,9 +203,7 @@ class JoinProbeHelper FuncType func_ptr_has_null = nullptr; FuncType func_ptr_no_null = nullptr; const HashJoin * join; - const HashJoinSettings & settings; const HashJoinPointerTable & pointer_table; - const HashJoinRowLayout & row_layout; }; } // namespace DB From 9227a0b5a0a0bfa516d31b9afa36b358f5360509 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Mon, 14 Apr 2025 16:16:25 +0800 Subject: [PATCH 28/84] u Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/HashJoinProbe.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h index c6f99149c6a..0900d0de6b4 100644 --- 
a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h @@ -109,7 +109,7 @@ struct alignas(CPU_CACHE_LINE_SIZE) JoinProbeWorkerData class JoinProbeHelperUtil { -public: +protected: explicit JoinProbeHelperUtil(const HashJoinSettings & settings, const HashJoinRowLayout & row_layout) : settings(settings) , row_layout(row_layout) @@ -164,7 +164,7 @@ struct JoinProbeAdder; bool tagged_pointer> class HashJoin; -class JoinProbeHelper final : public JoinProbeHelperUtil +class JoinProbeHelper : public JoinProbeHelperUtil { public: JoinProbeHelper(const HashJoin * join, bool late_materialization); From a3ed1e27f67ab104a3688c170defe89cb53b083e Mon Sep 17 00:00:00 2001 From: gengliqi Date: Mon, 14 Apr 2025 19:00:43 +0800 Subject: [PATCH 29/84] u Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp | 8 ++++---- dbms/src/Interpreters/JoinV2/HashJoinProbe.h | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index a06da07a23d..37c3b455e7a 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -87,7 +87,7 @@ void JoinProbeContext::prepareForHashProbe( /// reuse null_map to record the filtered rows, the rows contains NULL or does not /// match the join filter won't join to anything recordFilteredRows(block, filter_column, null_map_holder, null_map); - /// Some useless columns maybe key columns and filter column so they must be removed after extracting key columns and filter column. + /// Some useless columns maybe key columns and filter column so they must be removed after extracting. 
for (size_t pos = 0; pos < block.columns();) { if (!probe_output_name_set.contains(block.getByPosition(pos).name)) @@ -557,7 +557,7 @@ Block JoinProbeHelper::probe(JoinProbeContext & context, JoinProbeWorkerData & w return (this->*func_ptr_no_null)(context, wd); } -JOIN_PROBE_TEMPLATE +JOIN_PROBE_HELPER_TEMPLATE Block JoinProbeHelper::probeImpl(JoinProbeContext & context, JoinProbeWorkerData & wd) { static_assert(has_other_condition || !late_materialization); @@ -713,7 +713,7 @@ Block JoinProbeHelper::probeImpl(JoinProbeContext & context, JoinProbeWorkerData return join->output_block_after_finalize; } -JOIN_PROBE_TEMPLATE +JOIN_PROBE_HELPER_TEMPLATE void JoinProbeHelper::probeFillColumns( JoinProbeContext & context, JoinProbeWorkerData & wd, @@ -847,7 +847,7 @@ void JoinProbeHelper::probeFillColumns( #define PREFETCH_READ(ptr) __builtin_prefetch((ptr), 0 /* rw==read */, 3 /* locality */) -JOIN_PROBE_TEMPLATE +JOIN_PROBE_HELPER_TEMPLATE void JoinProbeHelper::probeFillColumnsPrefetch( JoinProbeContext & context, JoinProbeWorkerData & wd, diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h index 0900d0de6b4..0cf3469172b 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h @@ -154,7 +154,7 @@ class JoinProbeHelperUtil template struct JoinProbeAdder; -#define JOIN_PROBE_TEMPLATE \ +#define JOIN_PROBE_HELPER_TEMPLATE \ template < \ typename KeyGetter, \ ASTTableJoin::Kind kind, \ @@ -172,13 +172,13 @@ class JoinProbeHelper : public JoinProbeHelperUtil Block probe(JoinProbeContext & context, JoinProbeWorkerData & wd); private: - JOIN_PROBE_TEMPLATE + JOIN_PROBE_HELPER_TEMPLATE Block probeImpl(JoinProbeContext & context, JoinProbeWorkerData & wd); - JOIN_PROBE_TEMPLATE + JOIN_PROBE_HELPER_TEMPLATE void NO_INLINE probeFillColumns(JoinProbeContext & context, JoinProbeWorkerData & wd, MutableColumns & added_columns); - JOIN_PROBE_TEMPLATE + JOIN_PROBE_HELPER_TEMPLATE 
void NO_INLINE probeFillColumnsPrefetch(JoinProbeContext & context, JoinProbeWorkerData & wd, MutableColumns & added_columns); From 994e8a61e0d7830f3d3252c5a46eacaed8c6fbf3 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Tue, 15 Apr 2025 16:05:23 +0800 Subject: [PATCH 30/84] u Signed-off-by: gengliqi --- .../src/Interpreters/JoinV2/HashJoinProbe.cpp | 98 +++++++++---------- dbms/src/Interpreters/JoinV2/HashJoinProbe.h | 3 - 2 files changed, 45 insertions(+), 56 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index 37c3b455e7a..16113b06d30 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -211,36 +211,6 @@ void JoinProbeHelperUtil::fillNullMapWithZero(MutableColumns & added_columns) co } } -/// The implemtation of prefetching in join probe process is inspired by a paper named -/// `Asynchronous Memory Access Chaining` in vldb-15. -/// Ref: https://www.vldb.org/pvldb/vol9/p252-kocberber.pdf -enum class ProbePrefetchStage : UInt8 -{ - None, - FindHeader, - FindNext, -}; - -template -struct ProbePrefetchState -{ - using KeyGetterType = typename KeyGetter::Type; - using KeyType = typename KeyGetterType::KeyType; - using HashValueType = typename KeyGetter::HashValueType; - - ProbePrefetchStage stage = ProbePrefetchStage::None; - bool is_matched = false; - UInt16 hash_tag = 0; - UInt32 index = 0; - HashValueType hash = 0; - KeyType key{}; - union - { - RowPtr ptr = nullptr; - std::atomic * pointer_ptr; - }; -}; - template struct JoinProbeAdder { @@ -753,10 +723,13 @@ void JoinProbeHelper::probeFillColumns( for (; idx < context.rows; ++idx) { - if (has_null_map && (*context.null_map)[idx]) + if constexpr (has_null_map) { - NOT_MATCHED(true) - continue; + if ((*context.null_map)[idx]) + { + NOT_MATCHED(true) + continue; + } } const auto & key = key_getter.getJoinKey(idx); @@ -845,6 +818,36 @@ void JoinProbeHelper::probeFillColumns( #undef 
NOT_MATCHED } +/// The implementation of prefetching in join probe process is inspired by a paper named +/// `Asynchronous Memory Access Chaining` in vldb-15. +/// Ref: https://www.vldb.org/pvldb/vol9/p252-kocberber.pdf +enum class ProbePrefetchStage : UInt8 +{ + None, + FindHeader, + FindNext, +}; + +template +struct ProbePrefetchState +{ + using KeyGetterType = typename KeyGetter::Type; + using KeyType = typename KeyGetterType::KeyType; + using HashValueType = typename KeyGetter::HashValueType; + + ProbePrefetchStage stage = ProbePrefetchStage::None; + bool is_matched = false; + UInt16 hash_tag = 0; + UInt32 index = 0; + HashValueType hash = 0; + KeyType key{}; + union + { + RowPtr ptr = nullptr; + std::atomic * pointer_ptr; + }; +}; + #define PREFETCH_READ(ptr) __builtin_prefetch((ptr), 0 /* rw==read */, 3 /* locality */) JOIN_PROBE_HELPER_TEMPLATE @@ -859,7 +862,12 @@ void JoinProbeHelper::probeFillColumnsPrefetch( using Adder = JoinProbeAdder; auto & key_getter = *static_cast(context.key_getter.get()); - initPrefetchStates(context); + if (!context.prefetch_states) + { + context.prefetch_states = decltype(context.prefetch_states)( + static_cast(new ProbePrefetchState[settings.probe_prefetch_step]), + [](void * ptr) { delete[] static_cast *>(ptr); }); + } auto * states = static_cast *>(context.prefetch_states.get()); size_t idx = context.start_row_idx; @@ -1299,17 +1307,6 @@ Block JoinProbeHelper::handleOtherConditions( return output_block_after_finalize; } -template -void JoinProbeHelper::initPrefetchStates(JoinProbeContext & context) -{ - if (!context.prefetch_states) - { - context.prefetch_states = decltype(context.prefetch_states)( - static_cast(new ProbePrefetchState[settings.probe_prefetch_step]), - [](void * ptr) { delete[] static_cast *>(ptr); }); - } -} - Block JoinProbeHelper::fillNotMatchedRowsForLeftOuter(JoinProbeContext & context, JoinProbeWorkerData & wd) { RUNTIME_CHECK(join->kind == LeftOuter); @@ -1397,14 +1394,9 @@ Block
JoinProbeHelper::fillNotMatchedRowsForLeftOuter(JoinProbeContext & context context.rows_not_matched.clear(); } - if (result_size >= remaining_insert_size) - { - Block res_block; - res_block.swap(wd.result_block_for_other_condition); - return res_block; - } - - return output_block_after_finalize; + Block res_block; + res_block.swap(wd.result_block_for_other_condition); + return res_block; } Block JoinProbeHelper::genResultBlockForLeftOuterSemi(JoinProbeContext & context, JoinProbeWorkerData & wd) diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h index 0cf3469172b..2747271b75d 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h @@ -182,9 +182,6 @@ class JoinProbeHelper : public JoinProbeHelperUtil void NO_INLINE probeFillColumnsPrefetch(JoinProbeContext & context, JoinProbeWorkerData & wd, MutableColumns & added_columns); - template - void initPrefetchStates(JoinProbeContext & context); - Block handleOtherConditions( JoinProbeContext & context, JoinProbeWorkerData & wd, From 9ee3b815d39c043177c203fee1f0dd6a2f1be3bc Mon Sep 17 00:00:00 2001 From: gengliqi Date: Mon, 21 Apr 2025 15:36:38 +0800 Subject: [PATCH 31/84] use < 0 instead of == -1 for output_index Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp | 8 ++++---- dbms/src/Interpreters/JoinV2/HashJoinProbe.h | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index 16113b06d30..1fdecff0b43 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -1229,7 +1229,7 @@ Block JoinProbeHelper::handleOtherConditions( for (size_t i = 0; i < right_columns; ++i) { auto output_index = output_column_indexes.at(left_columns + i); - if (output_index == -1) + if (output_index < 0) continue; if unlikely 
(!filter_offsets_is_initialized) init_filter_offsets(); @@ -1243,7 +1243,7 @@ Block JoinProbeHelper::handleOtherConditions( for (size_t i = 0; i < left_columns; ++i) { auto output_index = output_column_indexes.at(i); - if (output_index == -1) + if (output_index < 0) continue; auto & des_column = wd.result_block_for_other_condition.safeGetByPosition(output_index); if (left_required_flag_for_other_condition[i]) @@ -1366,7 +1366,7 @@ Block JoinProbeHelper::fillNotMatchedRowsForLeftOuter(JoinProbeContext & context for (size_t i = 0; i < right_columns; ++i) { auto output_index = output_column_indexes.at(left_columns + i); - if (output_index == -1) + if (output_index < 0) continue; auto & des_column = wd.result_block_for_other_condition.safeGetByPosition(output_index); des_column.column->assumeMutable()->insertManyDefaults(length); @@ -1375,7 +1375,7 @@ Block JoinProbeHelper::fillNotMatchedRowsForLeftOuter(JoinProbeContext & context for (size_t i = 0; i < left_columns; ++i) { auto output_index = output_column_indexes.at(i); - if (output_index == -1) + if (output_index < 0) continue; auto & des_column = wd.result_block_for_other_condition.safeGetByPosition(output_index); auto & src_column = context.block.safeGetByPosition(i); diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h index 2747271b75d..1ebc52de6df 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h @@ -110,7 +110,7 @@ struct alignas(CPU_CACHE_LINE_SIZE) JoinProbeWorkerData class JoinProbeHelperUtil { protected: - explicit JoinProbeHelperUtil(const HashJoinSettings & settings, const HashJoinRowLayout & row_layout) + JoinProbeHelperUtil(const HashJoinSettings & settings, const HashJoinRowLayout & row_layout) : settings(settings) , row_layout(row_layout) {} From f75eee81947c2c2d9f6a82e8cff0074cf88861a3 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Mon, 21 Apr 2025 15:52:25 +0800 Subject: [PATCH 32/84] add 
comments Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/HashJoin.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoin.cpp b/dbms/src/Interpreters/JoinV2/HashJoin.cpp index 099861b7c04..feff1f03b4b 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoin.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoin.cpp @@ -475,7 +475,8 @@ void HashJoin::workAfterBuildRowFinish() enable_tagged_pointer, false); - const size_t lm_size_threshold = 16; + /// Conservative threshold: trigger late materialization when lm_row_size average >= 16 bytes. + constexpr size_t trigger_lm_row_size_threshold = 16; bool late_materialization = false; size_t avg_lm_row_size = 0; if (has_other_condition @@ -489,10 +490,11 @@ void HashJoin::workAfterBuildRowFinish() total_lm_row_count += build_workers_data[i].lm_row_count; } avg_lm_row_size = total_lm_row_count == 0 ? 0 : total_lm_row_size / total_lm_row_count; - late_materialization = avg_lm_row_size >= lm_size_threshold; + late_materialization = avg_lm_row_size >= trigger_lm_row_size_threshold; } fiu_do_on(FailPoints::force_join_v2_probe_enable_lm, { late_materialization = true; }); fiu_do_on(FailPoints::force_join_v2_probe_disable_lm, { late_materialization = false; }); + join_probe_helper = std::make_unique(this, late_materialization); LOG_INFO( From 66dd0d73d68394143c1084bbc604913ebebffb9c Mon Sep 17 00:00:00 2001 From: gengliqi Date: Mon, 21 Apr 2025 20:26:49 +0800 Subject: [PATCH 33/84] address comments Signed-off-by: gengliqi --- dbms/src/Flash/tests/gtest_join_executor.cpp | 77 ++++++++++- dbms/src/Interpreters/JoinV2/HashJoin.cpp | 35 ++--- .../src/Interpreters/JoinV2/HashJoinProbe.cpp | 124 +++++------------- 3 files changed, 120 insertions(+), 116 deletions(-) diff --git a/dbms/src/Flash/tests/gtest_join_executor.cpp b/dbms/src/Flash/tests/gtest_join_executor.cpp index aedbdca13c0..3f5ddcc5143 100644 --- a/dbms/src/Flash/tests/gtest_join_executor.cpp +++ 
b/dbms/src/Flash/tests/gtest_join_executor.cpp @@ -2202,7 +2202,7 @@ try toNullableVec("l4", {3, 1, 2, 0, 2, 3, 3, {}, 4, 5, 0, 6}), }); - // No other condition: lj_l_table left join lj_r_table + // lj_l_table left join lj_r_table, no other condition auto request = context.scan("test_db", "lj_l_table") .join( context.scan("test_db", "lj_r_table"), @@ -2288,7 +2288,7 @@ try }); WRAP_FOR_JOIN_TEST_END - // No other condition: lj_r_table left join lj_l_table + // lj_r_table left join lj_l_table, no other condition request = context.scan("test_db", "lj_r_table") .join( context.scan("test_db", "lj_l_table"), @@ -2354,7 +2354,7 @@ try }); WRAP_FOR_JOIN_TEST_END - // Has other condition: lj_l_table left join lj_r_table + // lj_l_table left join lj_r_table, other condition: l4 > r3 request = context.scan("test_db", "lj_l_table") .join( context.scan("test_db", "lj_r_table"), @@ -2390,7 +2390,7 @@ try }); WRAP_FOR_JOIN_FOR_OTHER_CONDITION_TEST_END - // Has other condition: lj_r_table left join lj_l_table + // lj_r_table left join lj_l_table, other condition: l4 > r3 request = context.scan("test_db", "lj_r_table") .join( context.scan("test_db", "lj_l_table"), @@ -2424,6 +2424,75 @@ try toNullableVec({3, 3, 6, 4, 5, 6, {}}), }); WRAP_FOR_JOIN_FOR_OTHER_CONDITION_TEST_END + + // lj_l_table left join lj_r_table, other condition: l4 > r3 or r3 < l3 + request + = context.scan("test_db", "lj_l_table") + .join( + context.scan("test_db", "lj_r_table"), + tipb::JoinType::TypeLeftOuterJoin, + {col("k1"), col("k2")}, + {eq(col("lj_l_table.l2"), col("lj_l_table.l3"))}, + {}, + {Or(gt(col("lj_l_table.l4"), col("lj_r_table.r3")), lt(col("lj_r_table.r3"), col("lj_l_table.l3")))}, + {}) + .project( + {"lj_l_table.l3", + "lj_r_table.r2", + "lj_r_table.r3", + "lj_r_table.k2", + "lj_l_table.k1", + "lj_l_table.l2", + "lj_l_table.l1", + "lj_r_table.r1"}) + .build(context); + WRAP_FOR_JOIN_FOR_OTHER_CONDITION_TEST_BEGIN + executeAndAssertColumnsEqual( + request, + { + toNullableVec({1, 1, 2, 3, 4, 
5, 0, {}, 6, 7, 7, 8, 9, 9, 10, 10}), + toNullableVec( + {"aaa", "bbb", {}, {}, {}, {}, {}, {}, "ddd", "ccc", "eee", "ddd", "aaa", "bbb", "ccc", "eee"}), + toNullableVec({1, 2, {}, {}, {}, {}, {}, {}, 4, 3, 5, 4, 1, 2, 3, 5}), + toNullableVec({1, 1, {}, {}, {}, {}, {}, {}, 3, 2, 2, 3, 1, 1, 2, 2}), + toNullableVec({1, 1, 1, 2, 2, {}, 3, 3, 3, 2, 2, 3, 1, 1, 2, 2}), + toVec({1, 1, 2, 3, 4, 5, 6, 6, 6, 7, 7, 8, 9, 9, 10, 10}), + toVec( + {"AAA", + "AAA", + "BBB", + "CCC", + "DDD", + "EEE", + "FFF", + "GGG", + "HHH", + "III", + "III", + "JJJ", + "KKK", + "KKK", + "LLL", + "LLL"}), + toNullableVec( + {"apple", + "banana", + {}, + {}, + {}, + {}, + {}, + {}, + "dog", + "cat", + "elephant", + "dog", + "apple", + "banana", + "cat", + "elephant"}), + }); + WRAP_FOR_JOIN_FOR_OTHER_CONDITION_TEST_END } CATCH diff --git a/dbms/src/Interpreters/JoinV2/HashJoin.cpp b/dbms/src/Interpreters/JoinV2/HashJoin.cpp index feff1f03b4b..8f9a71478b3 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoin.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoin.cpp @@ -253,7 +253,7 @@ void HashJoin::initRowLayoutAndHashJoinMethod() required_columns_flag[i] = true; continue; } - auto & c = right_sample_block_pruned.safeGetByPosition(i); + auto & c = right_sample_block_pruned.getByPosition(i); if (required_columns_names_set_for_other_condition.contains(c.name)) { ++row_layout.other_column_count_for_other_condition; @@ -273,7 +273,7 @@ void HashJoin::initRowLayoutAndHashJoinMethod() { if (required_columns_flag[i]) continue; - auto & c = right_sample_block_pruned.safeGetByPosition(i); + auto & c = right_sample_block_pruned.getByPosition(i); if (c.column->valuesHaveFixedSize()) { row_layout.other_column_fixed_size += c.column->sizeOfValueIfFixed(); @@ -289,6 +289,9 @@ void HashJoin::initRowLayoutAndHashJoinMethod() c.name); } RUNTIME_CHECK(row_layout.raw_key_column_indexes.size() + row_layout.other_column_indexes.size() == columns); + for (auto [column_index, is_nullable] : row_layout.raw_key_column_indexes) 
+ RUNTIME_CHECK( + right_sample_block_pruned.safeGetByPosition(column_index).column->isColumnNullable() == is_nullable); } void HashJoin::initBuild(const Block & sample_block, size_t build_concurrency_) @@ -622,7 +625,7 @@ Block HashJoin::probeBlock(JoinProbeContext & context, size_t stream_index) context.prepareForHashProbe( method, kind, - non_equal_conditions.other_cond_expr != nullptr, + has_other_condition, key_names_left, non_equal_conditions.left_filter_column, probe_output_name_set, @@ -690,7 +693,7 @@ void HashJoin::initOutputBlock(Block & block) const size_t output_columns = output_block_after_finalize.columns(); for (size_t i = 0; i < output_columns; ++i) { - ColumnWithTypeAndName new_column = output_block_after_finalize.safeGetByPosition(i).cloneEmpty(); + ColumnWithTypeAndName new_column = output_block_after_finalize.getByPosition(i).cloneEmpty(); new_column.column->assumeMutable()->reserveAlign(settings.max_block_size, FULL_VECTOR_SIZE_AVX2); block.insert(std::move(new_column)); } @@ -763,10 +766,8 @@ void HashJoin::finalize(const Names & parent_require) if (!match_helper_name.empty()) output_columns_names_set_for_other_condition_after_finalize.insert(match_helper_name); - if (non_equal_conditions.other_cond_expr != nullptr) - { - const auto & actions = non_equal_conditions.other_cond_expr->getActions(); - for (const auto & action : actions) + auto update_required_columns_names_set = [&](const ExpressionActionsPtr & expr) { + for (const auto & action : expr->getActions()) { Names needed_columns = action.getNeededColumns(); for (const auto & name : needed_columns) @@ -775,21 +776,13 @@ void HashJoin::finalize(const Names & parent_require) required_columns_names_set_for_other_condition.insert(name); } } - } + }; + + if (non_equal_conditions.other_cond_expr != nullptr) + update_required_columns_names_set(non_equal_conditions.other_cond_expr); if (non_equal_conditions.null_aware_eq_cond_expr != nullptr) - { - const auto & actions = 
non_equal_conditions.null_aware_eq_cond_expr->getActions(); - for (const auto & action : actions) - { - Names needed_columns = action.getNeededColumns(); - for (const auto & name : needed_columns) - { - if (output_columns_names_set_for_other_condition_after_finalize.contains(name)) - required_columns_names_set_for_other_condition.insert(name); - } - } - } + update_required_columns_names_set(non_equal_conditions.null_aware_eq_cond_expr); } /// remove duplicated column diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index 1fdecff0b43..152f794f186 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -126,62 +126,46 @@ void JoinProbeContext::prepareForHashProbe( template void JoinProbeHelperUtil::flushInsertBatch(JoinProbeWorkerData & wd, MutableColumns & added_columns) const { - if constexpr (late_materialization) + for (auto [column_index, is_nullable] : row_layout.raw_key_column_indexes) { - size_t idx = 0; - for (auto [_, is_nullable] : row_layout.raw_key_column_indexes) - { - IColumn * column = added_columns[idx].get(); - if (is_nullable) - column = &static_cast(*added_columns[idx]).getNestedColumn(); - column->deserializeAndInsertFromPos(wd.insert_batch, true); - ++idx; - } - for (size_t i = 0; i < row_layout.other_column_count_for_other_condition; ++i) - added_columns[idx++]->deserializeAndInsertFromPos(wd.insert_batch, true); - - wd.row_ptrs_for_lm.insert(wd.insert_batch.begin(), wd.insert_batch.end()); + IColumn * column = added_columns[column_index].get(); + if (is_nullable) + column = &static_cast(*added_columns[column_index]).getNestedColumn(); + column->deserializeAndInsertFromPos(wd.insert_batch, true); } + + size_t add_size; + if constexpr (late_materialization) + add_size = row_layout.other_column_count_for_other_condition; else + add_size = row_layout.other_column_indexes.size(); + for (size_t i = 0; i < add_size; ++i) + { + 
size_t column_index = row_layout.other_column_indexes[i].first; + added_columns[column_index]->deserializeAndInsertFromPos(wd.insert_batch, true); + } + if constexpr (late_materialization) + wd.row_ptrs_for_lm.insert(wd.insert_batch.begin(), wd.insert_batch.end()); + + if constexpr (last_flush) { for (auto [column_index, is_nullable] : row_layout.raw_key_column_indexes) { IColumn * column = added_columns[column_index].get(); if (is_nullable) column = &static_cast(*added_columns[column_index]).getNestedColumn(); - column->deserializeAndInsertFromPos(wd.insert_batch, true); + column->flushNTAlignBuffer(); } - for (auto [column_index, _] : row_layout.other_column_indexes) - added_columns[column_index]->deserializeAndInsertFromPos(wd.insert_batch, true); - } - if constexpr (last_flush) - { + size_t add_size; if constexpr (late_materialization) - { - size_t idx = 0; - for (auto [_, is_nullable] : row_layout.raw_key_column_indexes) - { - IColumn * column = added_columns[idx].get(); - if (is_nullable) - column = &static_cast(*added_columns[idx]).getNestedColumn(); - column->flushNTAlignBuffer(); - ++idx; - } - for (size_t i = 0; i < row_layout.other_column_count_for_other_condition; ++i) - added_columns[idx++]->flushNTAlignBuffer(); - } + add_size = row_layout.other_column_count_for_other_condition; else + add_size = row_layout.other_column_indexes.size(); + for (size_t i = 0; i < add_size; ++i) { - for (auto [column_index, is_nullable] : row_layout.raw_key_column_indexes) - { - IColumn * column = added_columns[column_index].get(); - if (is_nullable) - column = &static_cast(*added_columns[column_index]).getNestedColumn(); - column->flushNTAlignBuffer(); - } - for (auto [column_index, _] : row_layout.other_column_indexes) - added_columns[column_index]->flushNTAlignBuffer(); + size_t column_index = row_layout.other_column_indexes[i].first; + added_columns[column_index]->flushNTAlignBuffer(); } } @@ -191,23 +175,16 @@ void 
JoinProbeHelperUtil::flushInsertBatch(JoinProbeWorkerData & wd, MutableColu template void JoinProbeHelperUtil::fillNullMapWithZero(MutableColumns & added_columns) const { - size_t idx = 0; for (auto [column_index, is_nullable] : row_layout.raw_key_column_indexes) { if (is_nullable) { - size_t index; - if constexpr (late_materialization) - index = idx; - else - index = column_index; - auto & nullable_column = static_cast(*added_columns[index]); + auto & nullable_column = static_cast(*added_columns[column_index]); size_t data_size = nullable_column.getNestedColumn().size(); size_t nullmap_size = nullable_column.getNullMapColumn().size(); RUNTIME_CHECK(nullmap_size <= data_size); nullable_column.getNullMapColumn().getData().resize_fill_zero(data_size); } - ++idx; } } @@ -574,30 +551,9 @@ Block JoinProbeHelper::probeImpl(JoinProbeContext & context, JoinProbeWorkerData } } - MutableColumns added_columns; - if constexpr (late_materialization) - { - for (auto [column_index, is_nullable] : row_layout.raw_key_column_indexes) - { - added_columns.emplace_back( - wd.result_block.safeGetByPosition(left_columns + column_index).column->assumeMutable()); - RUNTIME_CHECK(added_columns.back()->isColumnNullable() == is_nullable); - } - for (size_t i = 0; i < row_layout.other_column_count_for_other_condition; ++i) - { - size_t column_index = row_layout.other_column_indexes[i].first; - added_columns.emplace_back( - wd.result_block.safeGetByPosition(left_columns + column_index).column->assumeMutable()); - } - } - else - { - added_columns.resize(right_columns); - for (size_t i = 0; i < right_columns; ++i) - added_columns[i] = wd.result_block.safeGetByPosition(left_columns + i).column->assumeMutable(); - for (auto [column_index, is_nullable] : row_layout.raw_key_column_indexes) - RUNTIME_CHECK(added_columns.at(column_index)->isColumnNullable() == is_nullable); - } + MutableColumns added_columns(right_columns); + for (size_t i = 0; i < right_columns; ++i) + added_columns[i] = 
wd.result_block.safeGetByPosition(left_columns + i).column->assumeMutable(); Stopwatch watch; if (pointer_table.enableProbePrefetch()) @@ -615,6 +571,9 @@ Block JoinProbeHelper::probeImpl(JoinProbeContext & context, JoinProbeWorkerData added_columns); wd.probe_hash_table_time += watch.elapsedFromLastTime(); + for (size_t i = 0; i < right_columns; ++i) + wd.result_block.safeGetByPosition(left_columns + i).column = std::move(added_columns[i]); + if constexpr (kind == Inner || kind == LeftOuter || kind == Semi || kind == Anti) { if (wd.selective_offsets.empty()) @@ -626,23 +585,6 @@ Block JoinProbeHelper::probeImpl(JoinProbeContext & context, JoinProbeWorkerData return genResultBlockForLeftOuterSemi(context, wd); } - if constexpr (late_materialization) - { - size_t idx = 0; - for (auto [column_index, _] : row_layout.raw_key_column_indexes) - wd.result_block.safeGetByPosition(left_columns + column_index).column = std::move(added_columns[idx++]); - for (size_t i = 0; i < row_layout.other_column_count_for_other_condition; ++i) - { - size_t column_index = row_layout.other_column_indexes[i].first; - wd.result_block.safeGetByPosition(left_columns + column_index).column = std::move(added_columns[idx++]); - } - } - else - { - for (size_t i = 0; i < right_columns; ++i) - wd.result_block.safeGetByPosition(left_columns + i).column = std::move(added_columns[i]); - } - if constexpr (has_other_condition) { // Always using late materialization for left side columns From c6966d3e4adaca8227f0742aa85499b42066bb40 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Mon, 21 Apr 2025 21:26:47 +0800 Subject: [PATCH 34/84] add comments Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index 152f794f186..8feea5f0230 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp 
@@ -571,6 +571,8 @@ Block JoinProbeHelper::probeImpl(JoinProbeContext & context, JoinProbeWorkerData added_columns); wd.probe_hash_table_time += watch.elapsedFromLastTime(); + // Move the mutable column pointers back into the wd.result_block, dropping the extra reference (ref_count 2→1). + // Alternative: added_columns.clear(); but that is less explicit and may misleadingly imply the columns are discarded. for (size_t i = 0; i < right_columns; ++i) wd.result_block.safeGetByPosition(left_columns + i).column = std::move(added_columns[i]); From 338e3a0b95e611f9154ebe0d1b254c062821d508 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Tue, 22 Apr 2025 14:49:47 +0800 Subject: [PATCH 35/84] u Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/HashJoinRowLayout.h | 2 +- tests/fullstack-test/mpp/rollup.test | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoinRowLayout.h b/dbms/src/Interpreters/JoinV2/HashJoinRowLayout.h index c0b921f5289..16d1c640082 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinRowLayout.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinRowLayout.h @@ -54,7 +54,7 @@ constexpr size_t ROW_PTR_TAG_SHIFT = 8 * sizeof(RowPtr) - ROW_PTR_TAG_BITS; static_assert(sizeof(RowPtr) == sizeof(uintptr_t)); static_assert(sizeof(RowPtr) == 8); -inline RowPtr getNextRowPtr(const RowPtr ptr) +inline RowPtr getNextRowPtr(RowPtr ptr) { return unalignedLoad(ptr); } diff --git a/tests/fullstack-test/mpp/rollup.test b/tests/fullstack-test/mpp/rollup.test index 84a24620ea6..3db530da41d 100644 --- a/tests/fullstack-test/mpp/rollup.test +++ b/tests/fullstack-test/mpp/rollup.test @@ -123,4 +123,4 @@ mysql> use test; set tidb_enforce_mpp=1; select d, a, avg(d), sum(a), min(b), ma | 7 | 6 | 7.0000 | 6 | 6 | NULL | +------+------+--------+--------+--------+--------+ -mysql> drop table if exists test.t1; \ No newline at end of file +mysql> drop table if exists test.t1; From 70c0e71ca54fd78419fac4bb0a618031064f4736 Mon Sep 17 
00:00:00 2001 From: gengliqi Date: Tue, 22 Apr 2025 16:19:43 +0800 Subject: [PATCH 36/84] u Signed-off-by: gengliqi --- .../src/Interpreters/JoinV2/HashJoinProbe.cpp | 119 ++++++------------ 1 file changed, 37 insertions(+), 82 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index 8feea5f0230..25718d342e3 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -197,7 +197,6 @@ struct JoinProbeAdder static bool ALWAYS_INLINE addMatched( JoinProbeHelper & helper, - JoinProbeContext &, JoinProbeWorkerData & wd, MutableColumns & added_columns, size_t idx, @@ -211,8 +210,7 @@ struct JoinProbeAdder return current_offset >= helper.settings.max_block_size; } - static bool ALWAYS_INLINE - addNotMatched(JoinProbeHelper &, JoinProbeContext &, JoinProbeWorkerData &, MutableColumns &, size_t, size_t &) + static bool ALWAYS_INLINE addNotMatched(JoinProbeHelper &, JoinProbeWorkerData &, size_t, size_t &) { return false; } @@ -233,7 +231,6 @@ struct JoinProbeAdder static bool ALWAYS_INLINE addMatched( JoinProbeHelper & helper, - JoinProbeContext &, JoinProbeWorkerData & wd, MutableColumns & added_columns, size_t idx, @@ -247,13 +244,8 @@ struct JoinProbeAdder return current_offset >= helper.settings.max_block_size; } - static bool ALWAYS_INLINE addNotMatched( - JoinProbeHelper & helper, - JoinProbeContext &, - JoinProbeWorkerData & wd, - MutableColumns &, - size_t idx, - size_t & current_offset) + static bool ALWAYS_INLINE + addNotMatched(JoinProbeHelper & helper, JoinProbeWorkerData & wd, size_t idx, size_t & current_offset) { if constexpr (!has_other_condition) { @@ -294,7 +286,6 @@ struct JoinProbeAdder static bool ALWAYS_INLINE addMatched( JoinProbeHelper & helper, - JoinProbeContext &, JoinProbeWorkerData & wd, MutableColumns &, size_t idx, @@ -307,8 +298,7 @@ struct JoinProbeAdder return current_offset >= helper.settings.max_block_size; } - 
static bool ALWAYS_INLINE - addNotMatched(JoinProbeHelper &, JoinProbeContext &, JoinProbeWorkerData &, MutableColumns &, size_t, size_t &) + static bool ALWAYS_INLINE addNotMatched(JoinProbeHelper &, JoinProbeWorkerData &, size_t, size_t &) { return false; } @@ -323,26 +313,14 @@ struct JoinProbeAdder static constexpr bool need_not_matched = true; static constexpr bool break_on_first_match = true; - static bool ALWAYS_INLINE addMatched( - JoinProbeHelper &, - JoinProbeContext &, - JoinProbeWorkerData &, - MutableColumns &, - size_t, - size_t &, - RowPtr, - size_t) + static bool ALWAYS_INLINE + addMatched(JoinProbeHelper &, JoinProbeWorkerData &, MutableColumns &, size_t, size_t &, RowPtr, size_t) { return false; } - static bool ALWAYS_INLINE addNotMatched( - JoinProbeHelper & helper, - JoinProbeContext &, - JoinProbeWorkerData & wd, - MutableColumns &, - size_t idx, - size_t & current_offset) + static bool ALWAYS_INLINE + addNotMatched(JoinProbeHelper & helper, JoinProbeWorkerData & wd, size_t idx, size_t & current_offset) { ++current_offset; wd.selective_offsets.push_back(idx); @@ -359,22 +337,14 @@ struct JoinProbeAdder static constexpr bool need_not_matched = false; static constexpr bool break_on_first_match = true; - static bool ALWAYS_INLINE addMatched( - JoinProbeHelper &, - JoinProbeContext &, - JoinProbeWorkerData & wd, - MutableColumns &, - size_t idx, - size_t &, - RowPtr, - size_t) + static bool ALWAYS_INLINE + addMatched(JoinProbeHelper &, JoinProbeWorkerData & wd, MutableColumns &, size_t idx, size_t &, RowPtr, size_t) { wd.match_helper_res[idx] = 1; return false; } - static bool ALWAYS_INLINE - addNotMatched(JoinProbeHelper &, JoinProbeContext &, JoinProbeWorkerData &, MutableColumns &, size_t, size_t &) + static bool ALWAYS_INLINE addNotMatched(JoinProbeHelper &, JoinProbeWorkerData &, size_t, size_t &) { return false; } @@ -389,26 +359,13 @@ struct JoinProbeAdder static constexpr bool need_not_matched = true; static constexpr bool 
break_on_first_match = true; - static bool ALWAYS_INLINE addMatched( - JoinProbeHelper &, - JoinProbeContext &, - JoinProbeWorkerData &, - MutableColumns &, - size_t, - size_t &, - RowPtr, - size_t) + static bool ALWAYS_INLINE + addMatched(JoinProbeHelper &, JoinProbeWorkerData &, MutableColumns &, size_t, size_t &, RowPtr, size_t) { return false; } - static bool ALWAYS_INLINE addNotMatched( - JoinProbeHelper &, - JoinProbeContext &, - JoinProbeWorkerData & wd, - MutableColumns &, - size_t idx, - size_t &) + static bool ALWAYS_INLINE addNotMatched(JoinProbeHelper &, JoinProbeWorkerData & wd, size_t idx, size_t &) { wd.match_helper_res[idx] = 1; return false; @@ -650,19 +607,19 @@ void JoinProbeHelper::probeFillColumns( key_offset += sizeof(HashValueType); } -#define NOT_MATCHED(not_matched) \ - if constexpr (Adder::need_not_matched) \ - { \ - assert(ptr == nullptr); \ - if (not_matched) \ - { \ - bool is_end = Adder::addNotMatched(*this, context, wd, added_columns, idx, current_offset); \ - if unlikely (is_end) \ - { \ - ++idx; \ - break; \ - } \ - } \ +#define NOT_MATCHED(not_matched) \ + if constexpr (Adder::need_not_matched) \ + { \ + assert(ptr == nullptr); \ + if (not_matched) \ + { \ + bool is_end = Adder::addNotMatched(*this, wd, idx, current_offset); \ + if unlikely (is_end) \ + { \ + ++idx; \ + break; \ + } \ + } \ } for (; idx < context.rows; ++idx) @@ -715,7 +672,6 @@ void JoinProbeHelper::probeFillColumns( { bool is_end = Adder::addMatched( *this, - context, wd, added_columns, idx, @@ -825,15 +781,15 @@ void JoinProbeHelper::probeFillColumnsPrefetch( key_offset += sizeof(HashValueType); } -#define NOT_MATCHED(not_matched, idx) \ - if constexpr (Adder::need_not_matched) \ - { \ - if (not_matched) \ - { \ - bool is_end = Adder::addNotMatched(*this, context, wd, added_columns, idx, current_offset); \ - if unlikely (is_end) \ - break; \ - } \ +#define NOT_MATCHED(not_matched, idx) \ + if constexpr (Adder::need_not_matched) \ + { \ + if (not_matched) \ + { \ 
+ bool is_end = Adder::addNotMatched(*this, wd, idx, current_offset); \ + if unlikely (is_end) \ + break; \ + } \ } const size_t probe_prefetch_step = settings.probe_prefetch_step; @@ -859,7 +815,6 @@ void JoinProbeHelper::probeFillColumnsPrefetch( { bool is_end = Adder::addMatched( *this, - context, wd, added_columns, state->index, @@ -939,7 +894,7 @@ void JoinProbeHelper::probeFillColumnsPrefetch( if constexpr (Adder::need_not_matched) { - is_end = Adder::addNotMatched(*this, context, wd, added_columns, idx, current_offset); + is_end = Adder::addNotMatched(*this, wd, idx, current_offset); if unlikely (is_end) { ++idx; From f4d41370f820ee0d4b449c85746e9d73bb74da74 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Tue, 22 Apr 2025 16:40:45 +0800 Subject: [PATCH 37/84] rename Signed-off-by: gengliqi --- .../src/Interpreters/JoinV2/HashJoinProbe.cpp | 62 +++++++++---------- dbms/src/Interpreters/JoinV2/HashJoinProbe.h | 4 +- 2 files changed, 33 insertions(+), 33 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index 25718d342e3..a58cc7ede0d 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -1023,24 +1023,24 @@ Block JoinProbeHelper::handleOtherConditions( size_t remaining_insert_size = settings.max_block_size - wd.result_block_for_other_condition.rows(); size_t result_size = countBytesInFilter(wd.filter); - bool filter_offsets_is_initialized = false; - auto init_filter_offsets = [&]() { + bool block_filter_offsets_is_initialized = false; + auto init_block_filter_offsets = [&]() { RUNTIME_CHECK(wd.filter.size() == rows); - wd.filter_offsets.clear(); - wd.filter_offsets.reserve(result_size); - filterImpl(&wd.filter[0], &wd.filter[rows], &BASE_OFFSETS[0], wd.filter_offsets); - RUNTIME_CHECK(wd.filter_offsets.size() == result_size); - filter_offsets_is_initialized = true; + wd.block_filter_offsets.clear(); + 
wd.block_filter_offsets.reserve(result_size); + filterImpl(&wd.filter[0], &wd.filter[rows], &BASE_OFFSETS[0], wd.block_filter_offsets); + RUNTIME_CHECK(wd.block_filter_offsets.size() == result_size); + block_filter_offsets_is_initialized = true; }; - bool filter_selective_offsets_is_initialized = false; - auto init_filter_selective_offsets = [&]() { + bool result_block_filter_offsets_is_initialized = false; + auto init_result_block_filter_offsets = [&]() { RUNTIME_CHECK(wd.selective_offsets.size() == rows); - wd.filter_selective_offsets.clear(); - wd.filter_selective_offsets.reserve(result_size); - filterImpl(&wd.filter[0], &wd.filter[rows], &wd.selective_offsets[0], wd.filter_selective_offsets); - RUNTIME_CHECK(wd.filter_selective_offsets.size() == result_size); - filter_selective_offsets_is_initialized = true; + wd.result_block_filter_offsets.clear(); + wd.result_block_filter_offsets.reserve(result_size); + filterImpl(&wd.filter[0], &wd.filter[rows], &wd.selective_offsets[0], wd.result_block_filter_offsets); + RUNTIME_CHECK(wd.result_block_filter_offsets.size() == result_size); + result_block_filter_offsets_is_initialized = true; }; bool filter_row_ptrs_for_lm_is_initialized = false; @@ -1064,12 +1064,12 @@ Block JoinProbeHelper::handleOtherConditions( auto output_index = output_column_indexes.at(left_columns + column_index); if (output_index < 0) continue; - if unlikely (!filter_offsets_is_initialized) - init_filter_offsets(); + if unlikely (!block_filter_offsets_is_initialized) + init_block_filter_offsets(); auto & des_column = wd.result_block_for_other_condition.safeGetByPosition(output_index); auto & src_column = wd.result_block.safeGetByPosition(left_columns + column_index); des_column.column->assumeMutable() - ->insertSelectiveRangeFrom(*src_column.column.get(), wd.filter_offsets, start, length); + ->insertSelectiveRangeFrom(*src_column.column.get(), wd.block_filter_offsets, start, length); } for (size_t i = 0; i < 
row_layout.other_column_count_for_other_condition; ++i) { @@ -1077,12 +1077,12 @@ Block JoinProbeHelper::handleOtherConditions( auto output_index = output_column_indexes.at(left_columns + column_index); if (output_index < 0) continue; - if unlikely (!filter_offsets_is_initialized) - init_filter_offsets(); + if unlikely (!block_filter_offsets_is_initialized) + init_block_filter_offsets(); auto & des_column = wd.result_block_for_other_condition.safeGetByPosition(output_index); auto & src_column = wd.result_block.safeGetByPosition(left_columns + column_index); des_column.column->assumeMutable() - ->insertSelectiveRangeFrom(*src_column.column.get(), wd.filter_offsets, start, length); + ->insertSelectiveRangeFrom(*src_column.column.get(), wd.block_filter_offsets, start, length); } if (!filter_row_ptrs_for_lm_is_initialized) @@ -1130,12 +1130,12 @@ Block JoinProbeHelper::handleOtherConditions( auto output_index = output_column_indexes.at(left_columns + i); if (output_index < 0) continue; - if unlikely (!filter_offsets_is_initialized) - init_filter_offsets(); + if unlikely (!block_filter_offsets_is_initialized) + init_block_filter_offsets(); auto & des_column = wd.result_block_for_other_condition.safeGetByPosition(output_index); auto & src_column = wd.result_block.safeGetByPosition(left_columns + i); des_column.column->assumeMutable() - ->insertSelectiveRangeFrom(*src_column.column.get(), wd.filter_offsets, start, length); + ->insertSelectiveRangeFrom(*src_column.column.get(), wd.block_filter_offsets, start, length); } } @@ -1147,30 +1147,30 @@ Block JoinProbeHelper::handleOtherConditions( auto & des_column = wd.result_block_for_other_condition.safeGetByPosition(output_index); if (left_required_flag_for_other_condition[i]) { - if unlikely (!filter_offsets_is_initialized && !filter_selective_offsets_is_initialized) - init_filter_selective_offsets(); - if (filter_offsets_is_initialized) + if unlikely (!block_filter_offsets_is_initialized && 
!result_block_filter_offsets_is_initialized) + init_result_block_filter_offsets(); + if (block_filter_offsets_is_initialized) { auto & src_column = wd.result_block.safeGetByPosition(i); des_column.column->assumeMutable() - ->insertSelectiveRangeFrom(*src_column.column.get(), wd.filter_offsets, start, length); + ->insertSelectiveRangeFrom(*src_column.column.get(), wd.block_filter_offsets, start, length); } else { auto & src_column = context.block.safeGetByPosition(i); des_column.column->assumeMutable()->insertSelectiveRangeFrom( *src_column.column.get(), - wd.filter_selective_offsets, + wd.result_block_filter_offsets, start, length); } continue; } - if unlikely (!filter_selective_offsets_is_initialized) - init_filter_selective_offsets(); + if unlikely (!result_block_filter_offsets_is_initialized) + init_result_block_filter_offsets(); auto & src_column = context.block.safeGetByPosition(i); des_column.column->assumeMutable() - ->insertSelectiveRangeFrom(*src_column.column.get(), wd.filter_selective_offsets, start, length); + ->insertSelectiveRangeFrom(*src_column.column.get(), wd.result_block_filter_offsets, start, length); } }; diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h index 1ebc52de6df..9bded64fdec 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h @@ -87,8 +87,8 @@ struct alignas(CPU_CACHE_LINE_SIZE) JoinProbeWorkerData /// For other condition ColumnVector::Container filter; - IColumn::Offsets filter_offsets; - IColumn::Offsets filter_selective_offsets; + IColumn::Offsets block_filter_offsets; + IColumn::Offsets result_block_filter_offsets; /// For late materialization RowPtrs row_ptrs_for_lm; RowPtrs filter_row_ptrs_for_lm; From 80800b7b21fe91d41f32554cbeedfb829ff598f5 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Tue, 22 Apr 2025 17:51:39 +0800 Subject: [PATCH 38/84] rename Signed-off-by: gengliqi --- 
dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp | 18 +++++++++--------- dbms/src/Interpreters/JoinV2/HashJoinProbe.h | 8 ++++---- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index a58cc7ede0d..c6c67a08f16 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -37,7 +37,7 @@ using enum ASTTableJoin::Kind; bool JoinProbeContext::isProbeFinished() const { - return start_row_idx >= rows + return current_row_idx >= rows // For prefetching && prefetch_active_states == 0; } @@ -54,8 +54,8 @@ void JoinProbeContext::resetBlock(Block & block_) block = block_; orignal_block = block_; rows = block.rows(); - start_row_idx = 0; - current_row_ptr = nullptr; + current_row_idx = 0; + current_build_row_ptr = nullptr; current_row_is_matched = false; prefetch_active_states = 0; @@ -597,8 +597,8 @@ void JoinProbeHelper::probeFillColumns( auto & key_getter = *static_cast(context.key_getter.get()); size_t current_offset = wd.result_block.rows(); - size_t idx = context.start_row_idx; - RowPtr ptr = context.current_row_ptr; + size_t idx = context.current_row_idx; + RowPtr ptr = context.current_build_row_ptr; bool is_matched = context.current_row_is_matched; size_t collision = 0; size_t key_offset = sizeof(RowPtr); @@ -710,8 +710,8 @@ void JoinProbeHelper::probeFillColumns( Adder::flush(*this, wd, added_columns); - context.start_row_idx = idx; - context.current_row_ptr = ptr; + context.current_row_idx = idx; + context.current_build_row_ptr = ptr; context.current_row_is_matched = is_matched; wd.collision += collision; @@ -770,7 +770,7 @@ void JoinProbeHelper::probeFillColumnsPrefetch( } auto * states = static_cast *>(context.prefetch_states.get()); - size_t idx = context.start_row_idx; + size_t idx = context.current_row_idx; size_t active_states = context.prefetch_active_states; size_t k = context.prefetch_iter; size_t 
current_offset = wd.result_block.rows(); @@ -939,7 +939,7 @@ void JoinProbeHelper::probeFillColumnsPrefetch( Adder::flush(*this, wd, added_columns); - context.start_row_idx = idx; + context.current_row_idx = idx; context.prefetch_active_states = active_states; context.prefetch_iter = k; wd.collision += collision; diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h index 9bded64fdec..945cd245120 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h @@ -36,8 +36,8 @@ struct JoinProbeContext /// original_block ensures that the reference counts for the key columns are never zero. Block orignal_block; size_t rows = 0; - size_t start_row_idx = 0; - RowPtr current_row_ptr = nullptr; + size_t current_row_idx = 0; + RowPtr current_build_row_ptr = nullptr; /// For left outer/(left outer) (anti) semi join without other conditions. bool current_row_is_matched = false; /// For left outer/(left outer) (anti) semi join with other conditions. 
@@ -78,9 +78,9 @@ struct JoinProbeContext struct alignas(CPU_CACHE_LINE_SIZE) JoinProbeWorkerData { IColumn::Offsets selective_offsets; - /// For left outer join with no other condition + /// For left outer join without other conditions IColumn::Offsets not_matched_selective_offsets; - /// For left outer (anti) semi join with no other condition + /// For left outer (anti) semi join without other conditions PaddedPODArray match_helper_res; RowPtrs insert_batch; From 302fab5eedbf0c96e5b17941e1e1563407b990df Mon Sep 17 00:00:00 2001 From: gengliqi Date: Tue, 22 Apr 2025 17:53:09 +0800 Subject: [PATCH 39/84] u Signed-off-by: gengliqi --- .../src/Interpreters/JoinV2/HashJoinProbe.cpp | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index c6c67a08f16..3233cd74fd5 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -410,26 +410,26 @@ JoinProbeHelper::JoinProbeHelper(const HashJoin * join, bool late_materializatio CALL2(KeyGetter, JoinType, false, false) \ } -#define CALL(KeyGetter) \ - { \ - auto kind = join->kind; \ - bool has_other_condition = join->has_other_condition; \ - if (kind == Inner) \ - CALL1(KeyGetter, Inner) \ - else if (kind == LeftOuter) \ - CALL1(KeyGetter, LeftOuter) \ - else if (kind == Semi && !has_other_condition) \ - CALL2(KeyGetter, Semi, false, false) \ - else if (kind == Anti && !has_other_condition) \ - CALL2(KeyGetter, Anti, false, false) \ - else if (kind == LeftOuterSemi && !has_other_condition) \ - CALL2(KeyGetter, LeftOuterSemi, false, false) \ - else if (kind == LeftOuterAnti && !has_other_condition) \ - CALL2(KeyGetter, LeftOuterAnti, false, false) \ - else \ - throw Exception( \ - fmt::format("Logical error: unknown combination of JOIN {}", magic_enum::enum_name(join->kind)), \ - ErrorCodes::LOGICAL_ERROR); \ +#define CALL(KeyGetter) \ + { \ + 
auto kind = join->kind; \ + bool has_other_condition = join->has_other_condition; \ + if (kind == Inner) \ + CALL1(KeyGetter, Inner) \ + else if (kind == LeftOuter) \ + CALL1(KeyGetter, LeftOuter) \ + else if (kind == Semi && !has_other_condition) \ + CALL2(KeyGetter, Semi, false, false) \ + else if (kind == Anti && !has_other_condition) \ + CALL2(KeyGetter, Anti, false, false) \ + else if (kind == LeftOuterSemi && !has_other_condition) \ + CALL2(KeyGetter, LeftOuterSemi, false, false) \ + else if (kind == LeftOuterAnti && !has_other_condition) \ + CALL2(KeyGetter, LeftOuterAnti, false, false) \ + else \ + throw Exception( \ + fmt::format("Logical error: unknown combination of JOIN {}", magic_enum::enum_name(kind)), \ + ErrorCodes::LOGICAL_ERROR); \ } switch (join->method) From d7836d5dae87293d4e7959b329402442744b6419 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Tue, 22 Apr 2025 18:00:23 +0800 Subject: [PATCH 40/84] u Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp | 4 ++-- dbms/src/Interpreters/JoinV2/HashJoinProbe.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index 3233cd74fd5..acdce9877f0 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -61,7 +61,7 @@ void JoinProbeContext::resetBlock(Block & block_) prefetch_active_states = 0; is_prepared = false; - materialized_columns.clear(); + materialized_key_columns.clear(); key_columns.clear(); null_map = nullptr; null_map_holder = nullptr; @@ -81,7 +81,7 @@ void JoinProbeContext::prepareForHashProbe( if (is_prepared) return; - key_columns = extractAndMaterializeKeyColumns(block, materialized_columns, key_names); + key_columns = extractAndMaterializeKeyColumns(block, materialized_key_columns, key_names); /// Keys with NULL value in any column won't join to anything. 
extractNestedColumnsAndNullMap(key_columns, null_map_holder, null_map); /// reuse null_map to record the filtered rows, the rows contains NULL or does not diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h index 945cd245120..672aa384820 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h @@ -51,7 +51,7 @@ struct JoinProbeContext std::unique_ptr> prefetch_states; bool is_prepared = false; - Columns materialized_columns; + Columns materialized_key_columns; ColumnRawPtrs key_columns; ColumnPtr null_map_holder = nullptr; ConstNullMapPtr null_map = nullptr; From 1450e9e29de18e9cbdcbdc0882b7a5b7709d4fbd Mon Sep 17 00:00:00 2001 From: gengliqi Date: Tue, 22 Apr 2025 20:45:34 +0800 Subject: [PATCH 41/84] add SemiJoinProbe Signed-off-by: gengliqi --- dbms/CMakeLists.txt | 1 + dbms/src/Interpreters/JoinV2/HashJoin.h | 2 + .../src/Interpreters/JoinV2/HashJoinProbe.cpp | 14 +- dbms/src/Interpreters/JoinV2/HashJoinProbe.h | 4 +- .../src/Interpreters/JoinV2/SemiJoinProbe.cpp | 111 +++++++++++++++ dbms/src/Interpreters/JoinV2/SemiJoinProbe.h | 127 ++++++++++++++++++ 6 files changed, 257 insertions(+), 2 deletions(-) create mode 100644 dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp create mode 100644 dbms/src/Interpreters/JoinV2/SemiJoinProbe.h diff --git a/dbms/CMakeLists.txt b/dbms/CMakeLists.txt index e8cc291163e..af756df9205 100644 --- a/dbms/CMakeLists.txt +++ b/dbms/CMakeLists.txt @@ -113,6 +113,7 @@ check_then_add_sources_compile_flag ( src/Interpreters/JoinV2/HashJoin.cpp src/Interpreters/JoinV2/HashJoinBuild.cpp src/Interpreters/JoinV2/HashJoinProbe.cpp + src/Interpreters/JoinV2/SemiJoinProbe.cpp src/IO/Compression/EncodingUtil.cpp src/Storages/DeltaMerge/BitmapFilter/BitmapFilter.cpp src/Storages/DeltaMerge/DMVersionFilterBlockInputStream.cpp diff --git a/dbms/src/Interpreters/JoinV2/HashJoin.h b/dbms/src/Interpreters/JoinV2/HashJoin.h index 
077341a4dbd..ffdab906747 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoin.h +++ b/dbms/src/Interpreters/JoinV2/HashJoin.h @@ -26,6 +26,7 @@ #include #include #include +#include namespace DB @@ -82,6 +83,7 @@ class HashJoin private: friend JoinProbeHelper; + friend SemiJoinProbeHelper; static const DataTypePtr match_helper_type; diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index acdce9877f0..bd23eed45a5 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -120,6 +121,17 @@ void JoinProbeContext::prepareForHashProbe( not_matched_offsets.clear(); } + if ((kind == Semi || kind == Anti || kind == LeftOuterSemi || kind == LeftOuterAnti) && has_other_condition) + { + if unlikely (!semi_join_pending_probe_list) + { + semi_join_pending_probe_list = decltype(semi_join_pending_probe_list)( + static_cast(new SemiJoinPendingProbeList), + [](void * ptr) { delete static_cast(ptr); }); + } + static_cast(semi_join_pending_probe_list.get())->reset(block.rows() + 1); + } + is_prepared = true; } @@ -762,7 +774,7 @@ void JoinProbeHelper::probeFillColumnsPrefetch( using Adder = JoinProbeAdder; auto & key_getter = *static_cast(context.key_getter.get()); - if (!context.prefetch_states) + if unlikely (!context.prefetch_states) { context.prefetch_states = decltype(context.prefetch_states)( static_cast(new ProbePrefetchState[settings.probe_prefetch_step]), diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h index 672aa384820..0e293ae98ac 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h @@ -40,11 +40,13 @@ struct JoinProbeContext RowPtr current_build_row_ptr = nullptr; /// For left outer/(left outer) (anti) semi join without other conditions. 
bool current_row_is_matched = false; - /// For left outer/(left outer) (anti) semi join with other conditions. + /// For left outer with other conditions. IColumn::Filter rows_not_matched; /// < 0 means not_matched_offsets is not initialized. ssize_t not_matched_offsets_idx = -1; IColumn::Offsets not_matched_offsets; + /// For (left outer) (anti) semi join with other conditions. + std::unique_ptr> semi_join_pending_probe_list; size_t prefetch_active_states = 0; size_t prefetch_iter = 0; diff --git a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp new file mode 100644 index 00000000000..62bfcff54dc --- /dev/null +++ b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp @@ -0,0 +1,111 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +#ifdef TIFLASH_ENABLE_AVX_SUPPORT +ASSERT_USE_AVX2_COMPILE_FLAG +#endif + +namespace DB +{ + +using enum ASTTableJoin::Kind; + +enum class SemiJoinProbeResType : UInt8 +{ + FALSE_VALUE, + TRUE_VALUE, + NULL_VALUE, +}; + +SemiJoinProbeHelper::SemiJoinProbeHelper(const HashJoin * join) + : JoinProbeHelperUtil(join->settings, join->row_layout) + , join(join) + , pointer_table(join->pointer_table) +{ + RUNTIME_CHECK(join->has_other_condition); + +#define CALL2(KeyGetter, JoinType, tagged_pointer) \ + { \ + func_ptr_has_null = &SemiJoinProbeHelper::probeImpl; \ + func_ptr_no_null = &SemiJoinProbeHelper::probeImpl; \ + } + +#define CALL1(KeyGetter, JoinType) \ + { \ + if (pointer_table.enableTaggedPointer()) \ + CALL2(KeyGetter, JoinType, true) \ + else \ + CALL2(KeyGetter, JoinType, false) \ + } + +#define CALL(KeyGetter) \ + { \ + auto kind = join->kind; \ + if (kind == Semi) \ + CALL1(KeyGetter, Semi) \ + else if (kind == Anti) \ + CALL1(KeyGetter, Anti) \ + else if (kind == LeftOuterSemi) \ + CALL1(KeyGetter, LeftOuterSemi) \ + else if (kind == LeftOuterAnti) \ + CALL1(KeyGetter, LeftOuterAnti) \ + else \ + throw Exception( \ + fmt::format("Logical error: unknown combination of JOIN {}", magic_enum::enum_name(kind)), \ + ErrorCodes::LOGICAL_ERROR); \ + } + + switch (join->method) + { +#define M(METHOD) \ + case HashJoinKeyMethod::METHOD: \ + using KeyGetterType##METHOD = HashJoinKeyGetterForType; \ + CALL(KeyGetterType##METHOD); \ + break; + APPLY_FOR_HASH_JOIN_VARIANTS(M) +#undef M + + default: + throw Exception( + fmt::format("Unknown JOIN keys variant {}.", magic_enum::enum_name(join->method)), + ErrorCodes::UNKNOWN_SET_DATA_VARIANT); + } +#undef CALL +#undef CALL1 +#undef CALL2 +} + +Block SemiJoinProbeHelper::probe(JoinProbeContext & context, JoinProbeWorkerData & wd) +{ + if (context.null_map) + return (this->*func_ptr_has_null)(context, wd); + else + return (this->*func_ptr_no_null)(context, wd); +} + +template +Block 
SemiJoinProbeHelper::probeImpl(JoinProbeContext & context, JoinProbeWorkerData & wd) +{ + if unlikely (context.rows == 0) + return join->output_block_after_finalize; + + auto * probe_list = static_cast(context.semi_join_pending_probe_list.get()); + RUNTIME_CHECK(probe_list->size() == context.rows); +} + + +} // namespace DB diff --git a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.h b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.h new file mode 100644 index 00000000000..d2675347718 --- /dev/null +++ b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.h @@ -0,0 +1,127 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace DB +{ + +/// A index‑based doubly‑linked circular list for managing semi join pending probe rows. +/// After reset(N), it holds N entries plus a sentinel at index 0, and supports O(1) insert/remove by index. 
+class SemiJoinPendingProbeList +{ +public: + using index_t = UInt32; + + struct PendingProbeRow + { + RowPtr build_row_ptr; + UInt32 pace; + /// Embedded list pointers + UInt32 prev_node; + UInt32 next_node; + }; + + class Iterator + { + public: + Iterator(SemiJoinPendingProbeList * list, index_t idx) + : list(list) + , idx(idx) + {} + + PendingProbeRow & operator*() const { return list->probe_rows[idx]; } + PendingProbeRow * operator->() const { return &list->probe_rows[idx]; } + + Iterator & operator++() + { + idx = list->probe_rows[idx].next_node; + return *this; + } + + bool operator!=(const Iterator & other) const { return idx != other.idx; } + + private: + SemiJoinPendingProbeList * list; + index_t idx; + }; + + SemiJoinPendingProbeList() = default; + + void reset(size_t n) + { + RUNTIME_CHECK(n <= UINT32_MAX); + probe_rows.clear(); + probe_rows.resize(n + 1); + // Sentinel circular self-loop + probe_rows[0].next_node = 0; + probe_rows[0].prev_node = 0; + } + + inline size_t size() { return probe_rows.size() - 1; } + + /// Append an existing slot by index at the tail (before sentinel) + inline void append(index_t idx) + { + index_t tail = probe_rows[0].prev_node; + probe_rows[tail].next_node = idx; + probe_rows[idx].prev_node = tail; + probe_rows[idx].next_node = 0; + probe_rows[0].prev_node = idx; + } + + /// Remove a slot by index from the list + inline void remove(index_t idx) + { + index_t prev = probe_rows[idx].prev_node; + index_t next = probe_rows[idx].next_node; + probe_rows[prev].next_node = next; + probe_rows[next].prev_node = prev; + } + + Iterator begin() { return Iterator(this, probe_rows[0].next_node); } + Iterator end() { return Iterator(this, 0); } + + PendingProbeRow & operator[](index_t idx) { return probe_rows[idx]; } + const PendingProbeRow & operator[](index_t idx) const { return probe_rows[idx]; } + +private: + PaddedPODArray probe_rows; +}; + +class HashJoin; +class SemiJoinProbeHelper : public JoinProbeHelperUtil +{ +public: + 
explicit SemiJoinProbeHelper(const HashJoin * join); + + Block probe(JoinProbeContext & context, JoinProbeWorkerData & wd); + +private: + template + Block probeImpl(JoinProbeContext & context, JoinProbeWorkerData & wd); + +private: + using FuncType = Block (SemiJoinProbeHelper::*)(JoinProbeContext &, JoinProbeWorkerData &); + FuncType func_ptr_has_null = nullptr; + FuncType func_ptr_no_null = nullptr; + + const HashJoin * join; + const HashJoinPointerTable & pointer_table; +}; + + +} // namespace DB From 6024b7214ec7bf31be55f138c3b711ec0503a6db Mon Sep 17 00:00:00 2001 From: gengliqi Date: Wed, 23 Apr 2025 11:32:49 +0800 Subject: [PATCH 42/84] u Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index acdce9877f0..594b457d2e8 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -836,9 +836,7 @@ void JoinProbeHelper::probeFillColumnsPrefetch( } if constexpr (Adder::break_on_first_match) - { next_ptr = nullptr; - } } if (next_ptr) From 2627cd5ce715e643c712fde808c85d60c6194709 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Wed, 23 Apr 2025 23:51:11 +0800 Subject: [PATCH 43/84] tiny refine Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/HashJoin.cpp | 10 +- dbms/src/Interpreters/JoinV2/HashJoin.h | 2 +- .../src/Interpreters/JoinV2/HashJoinProbe.cpp | 255 ++++++++++-------- dbms/src/Interpreters/JoinV2/HashJoinProbe.h | 19 +- 4 files changed, 156 insertions(+), 130 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoin.cpp b/dbms/src/Interpreters/JoinV2/HashJoin.cpp index 8f9a71478b3..5dfb34755f6 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoin.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoin.cpp @@ -611,7 +611,7 @@ bool HashJoin::buildPointerTable(size_t stream_index) return is_end; } -Block 
HashJoin::probeBlock(JoinProbeContext & context, size_t stream_index) +Block HashJoin::probeBlock(JoinProbeContext & ctx, size_t stream_index) { RUNTIME_ASSERT(stream_index < probe_concurrency); RUNTIME_CHECK_MSG(probe_initialized, "Logical error: Join probe was not initialized"); @@ -622,7 +622,7 @@ Block HashJoin::probeBlock(JoinProbeContext & context, size_t stream_index) const NameSet & probe_output_name_set = has_other_condition ? output_columns_names_set_for_other_condition_after_finalize : output_column_names_set_after_finalize; - context.prepareForHashProbe( + ctx.prepareForHashProbe( method, kind, has_other_condition, @@ -636,9 +636,9 @@ Block HashJoin::probeBlock(JoinProbeContext & context, size_t stream_index) FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::random_join_prob_failpoint); auto & wd = probe_workers_data[stream_index]; - Block res = join_probe_helper->probe(context, wd); - if (context.isAllFinished()) - wd.probe_handle_rows += context.rows; + Block res = join_probe_helper->probe(ctx, wd); + if (ctx.isAllFinished()) + wd.probe_handle_rows += ctx.rows; return res; } diff --git a/dbms/src/Interpreters/JoinV2/HashJoin.h b/dbms/src/Interpreters/JoinV2/HashJoin.h index 077341a4dbd..c62047db4a4 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoin.h +++ b/dbms/src/Interpreters/JoinV2/HashJoin.h @@ -57,7 +57,7 @@ class HashJoin void buildRowFromBlock(const Block & block, size_t stream_index); bool buildPointerTable(size_t stream_index); - Block probeBlock(JoinProbeContext & context, size_t stream_index); + Block probeBlock(JoinProbeContext & ctx, size_t stream_index); Block probeLastResultBlock(size_t stream_index); void removeUselessColumn(Block & block) const; diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index 594b457d2e8..d8de077f7d2 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -115,10 +115,14 @@ void 
JoinProbeContext::prepareForHashProbe( if (kind == LeftOuter && has_other_condition) { rows_not_matched.clear(); - rows_not_matched.resize_fill(block.rows(), 1); + rows_not_matched.resize_fill(rows, 1); not_matched_offsets_idx = -1; not_matched_offsets.clear(); } + if (kind == LeftOuterSemi || kind == LeftOuterAnti) + { + semi_match_res.resize(rows); + } is_prepared = true; } @@ -197,6 +201,7 @@ struct JoinProbeAdder static bool ALWAYS_INLINE addMatched( JoinProbeHelper & helper, + JoinProbeContext &, JoinProbeWorkerData & wd, MutableColumns & added_columns, size_t idx, @@ -210,7 +215,8 @@ struct JoinProbeAdder return current_offset >= helper.settings.max_block_size; } - static bool ALWAYS_INLINE addNotMatched(JoinProbeHelper &, JoinProbeWorkerData &, size_t, size_t &) + static bool ALWAYS_INLINE + addNotMatched(JoinProbeHelper &, JoinProbeContext &, JoinProbeWorkerData &, size_t, size_t &) { return false; } @@ -231,6 +237,7 @@ struct JoinProbeAdder static bool ALWAYS_INLINE addMatched( JoinProbeHelper & helper, + JoinProbeContext &, JoinProbeWorkerData & wd, MutableColumns & added_columns, size_t idx, @@ -244,8 +251,12 @@ struct JoinProbeAdder return current_offset >= helper.settings.max_block_size; } - static bool ALWAYS_INLINE - addNotMatched(JoinProbeHelper & helper, JoinProbeWorkerData & wd, size_t idx, size_t & current_offset) + static bool ALWAYS_INLINE addNotMatched( + JoinProbeHelper & helper, + JoinProbeContext &, + JoinProbeWorkerData & wd, + size_t idx, + size_t & current_offset) { if constexpr (!has_other_condition) { @@ -286,6 +297,7 @@ struct JoinProbeAdder static bool ALWAYS_INLINE addMatched( JoinProbeHelper & helper, + JoinProbeContext &, JoinProbeWorkerData & wd, MutableColumns &, size_t idx, @@ -298,7 +310,8 @@ struct JoinProbeAdder return current_offset >= helper.settings.max_block_size; } - static bool ALWAYS_INLINE addNotMatched(JoinProbeHelper &, JoinProbeWorkerData &, size_t, size_t &) + static bool ALWAYS_INLINE + 
addNotMatched(JoinProbeHelper &, JoinProbeContext &, JoinProbeWorkerData &, size_t, size_t &) { return false; } @@ -313,14 +326,25 @@ struct JoinProbeAdder static constexpr bool need_not_matched = true; static constexpr bool break_on_first_match = true; - static bool ALWAYS_INLINE - addMatched(JoinProbeHelper &, JoinProbeWorkerData &, MutableColumns &, size_t, size_t &, RowPtr, size_t) + static bool ALWAYS_INLINE addMatched( + JoinProbeHelper &, + JoinProbeContext &, + JoinProbeWorkerData &, + MutableColumns &, + size_t, + size_t &, + RowPtr, + size_t) { return false; } - static bool ALWAYS_INLINE - addNotMatched(JoinProbeHelper & helper, JoinProbeWorkerData & wd, size_t idx, size_t & current_offset) + static bool ALWAYS_INLINE addNotMatched( + JoinProbeHelper & helper, + JoinProbeContext &, + JoinProbeWorkerData & wd, + size_t idx, + size_t & current_offset) { ++current_offset; wd.selective_offsets.push_back(idx); @@ -337,14 +361,22 @@ struct JoinProbeAdder static constexpr bool need_not_matched = false; static constexpr bool break_on_first_match = true; - static bool ALWAYS_INLINE - addMatched(JoinProbeHelper &, JoinProbeWorkerData & wd, MutableColumns &, size_t idx, size_t &, RowPtr, size_t) + static bool ALWAYS_INLINE addMatched( + JoinProbeHelper &, + JoinProbeContext & ctx, + JoinProbeWorkerData &, + MutableColumns &, + size_t idx, + size_t &, + RowPtr, + size_t) { - wd.match_helper_res[idx] = 1; + ctx.semi_match_res[idx] = 1; return false; } - static bool ALWAYS_INLINE addNotMatched(JoinProbeHelper &, JoinProbeWorkerData &, size_t, size_t &) + static bool ALWAYS_INLINE + addNotMatched(JoinProbeHelper &, JoinProbeContext &, JoinProbeWorkerData &, size_t, size_t &) { return false; } @@ -365,9 +397,10 @@ struct JoinProbeAdder return false; } - static bool ALWAYS_INLINE addNotMatched(JoinProbeHelper &, JoinProbeWorkerData & wd, size_t idx, size_t &) + static bool ALWAYS_INLINE + addNotMatched(JoinProbeHelper &, JoinProbeContext & ctx, JoinProbeWorkerData &, 
size_t idx, size_t &) { - wd.match_helper_res[idx] = 1; + ctx.semi_match_res[idx] = 1; return false; } @@ -453,26 +486,26 @@ JoinProbeHelper::JoinProbeHelper(const HashJoin * join, bool late_materializatio #undef CALL3 } -Block JoinProbeHelper::probe(JoinProbeContext & context, JoinProbeWorkerData & wd) +Block JoinProbeHelper::probe(JoinProbeContext & ctx, JoinProbeWorkerData & wd) { - if (context.null_map) - return (this->*func_ptr_has_null)(context, wd); + if (ctx.null_map) + return (this->*func_ptr_has_null)(ctx, wd); else - return (this->*func_ptr_no_null)(context, wd); + return (this->*func_ptr_no_null)(ctx, wd); } JOIN_PROBE_HELPER_TEMPLATE -Block JoinProbeHelper::probeImpl(JoinProbeContext & context, JoinProbeWorkerData & wd) +Block JoinProbeHelper::probeImpl(JoinProbeContext & ctx, JoinProbeWorkerData & wd) { static_assert(has_other_condition || !late_materialization); - if unlikely (context.rows == 0) + if unlikely (ctx.rows == 0) return join->output_block_after_finalize; if constexpr (kind == LeftOuter && has_other_condition) { - if (context.isProbeFinished()) - return fillNotMatchedRowsForLeftOuter(context, wd); + if (ctx.isProbeFinished()) + return fillNotMatchedRowsForLeftOuter(ctx, wd); } wd.insert_batch.clear(); @@ -484,11 +517,6 @@ Block JoinProbeHelper::probeImpl(JoinProbeContext & context, JoinProbeWorkerData wd.not_matched_selective_offsets.clear(); wd.not_matched_selective_offsets.reserve(settings.max_block_size); } - if constexpr ((kind == LeftOuterSemi || kind == LeftOuterAnti) && !has_other_condition) - { - wd.match_helper_res.clear(); - wd.match_helper_res.resize_fill_zero(context.rows); - } if constexpr (late_materialization) { wd.row_ptrs_for_lm.clear(); @@ -520,10 +548,10 @@ Block JoinProbeHelper::probeImpl(JoinProbeContext & context, JoinProbeWorkerData has_null_map, has_other_condition, late_materialization, - tagged_pointer>(context, wd, added_columns); + tagged_pointer>(ctx, wd, added_columns); else probeFillColumns( - context, + ctx, 
wd, added_columns); wd.probe_hash_table_time += watch.elapsedFromLastTime(); @@ -541,7 +569,7 @@ Block JoinProbeHelper::probeImpl(JoinProbeContext & context, JoinProbeWorkerData if constexpr (kind == LeftOuterSemi || kind == LeftOuterAnti) { - return genResultBlockForLeftOuterSemi(context, wd); + return genResultBlockForLeftOuterSemi(ctx); } if constexpr (has_other_condition) @@ -552,7 +580,7 @@ Block JoinProbeHelper::probeImpl(JoinProbeContext & context, JoinProbeWorkerData if (!join->left_required_flag_for_other_condition[i]) continue; wd.result_block.safeGetByPosition(i).column->assumeMutable()->insertSelectiveFrom( - *context.block.safeGetByPosition(i).column.get(), + *ctx.block.safeGetByPosition(i).column.get(), wd.selective_offsets); } } @@ -561,7 +589,7 @@ Block JoinProbeHelper::probeImpl(JoinProbeContext & context, JoinProbeWorkerData for (size_t i = 0; i < left_columns; ++i) { wd.result_block.safeGetByPosition(i).column->assumeMutable()->insertSelectiveFrom( - *context.block.safeGetByPosition(i).column.get(), + *ctx.block.safeGetByPosition(i).column.get(), wd.selective_offsets); } } @@ -570,7 +598,7 @@ Block JoinProbeHelper::probeImpl(JoinProbeContext & context, JoinProbeWorkerData if constexpr (has_other_condition) { - auto res_block = handleOtherConditions(context, wd, kind, late_materialization); + auto res_block = handleOtherConditions(ctx, wd, kind, late_materialization); wd.other_condition_time += watch.elapsedFromLastTime(); return res_block; } @@ -585,21 +613,18 @@ Block JoinProbeHelper::probeImpl(JoinProbeContext & context, JoinProbeWorkerData } JOIN_PROBE_HELPER_TEMPLATE -void JoinProbeHelper::probeFillColumns( - JoinProbeContext & context, - JoinProbeWorkerData & wd, - MutableColumns & added_columns) +void JoinProbeHelper::probeFillColumns(JoinProbeContext & ctx, JoinProbeWorkerData & wd, MutableColumns & added_columns) { using KeyGetterType = typename KeyGetter::Type; using Hash = typename KeyGetter::Hash; using HashValueType = typename 
KeyGetter::HashValueType; using Adder = JoinProbeAdder; - auto & key_getter = *static_cast(context.key_getter.get()); + auto & key_getter = *static_cast(ctx.key_getter.get()); size_t current_offset = wd.result_block.rows(); - size_t idx = context.current_row_idx; - RowPtr ptr = context.current_build_row_ptr; - bool is_matched = context.current_row_is_matched; + size_t idx = ctx.current_row_idx; + RowPtr ptr = ctx.current_build_row_ptr; + bool is_matched = ctx.current_row_is_matched; size_t collision = 0; size_t key_offset = sizeof(RowPtr); if constexpr (KeyGetterType::joinKeyCompareHashFirst()) @@ -607,26 +632,26 @@ void JoinProbeHelper::probeFillColumns( key_offset += sizeof(HashValueType); } -#define NOT_MATCHED(not_matched) \ - if constexpr (Adder::need_not_matched) \ - { \ - assert(ptr == nullptr); \ - if (not_matched) \ - { \ - bool is_end = Adder::addNotMatched(*this, wd, idx, current_offset); \ - if unlikely (is_end) \ - { \ - ++idx; \ - break; \ - } \ - } \ +#define NOT_MATCHED(not_matched) \ + if constexpr (Adder::need_not_matched) \ + { \ + assert(ptr == nullptr); \ + if (not_matched) \ + { \ + bool is_end = Adder::addNotMatched(*this, ctx, wd, idx, current_offset); \ + if unlikely (is_end) \ + { \ + ++idx; \ + break; \ + } \ + } \ } - for (; idx < context.rows; ++idx) + for (; idx < ctx.rows; ++idx) { if constexpr (has_null_map) { - if ((*context.null_map)[idx]) + if ((*ctx.null_map)[idx]) { NOT_MATCHED(true) continue; @@ -672,6 +697,7 @@ void JoinProbeHelper::probeFillColumns( { bool is_end = Adder::addMatched( *this, + ctx, wd, added_columns, idx, @@ -710,9 +736,9 @@ void JoinProbeHelper::probeFillColumns( Adder::flush(*this, wd, added_columns); - context.current_row_idx = idx; - context.current_build_row_ptr = ptr; - context.current_row_is_matched = is_matched; + ctx.current_row_idx = idx; + ctx.current_build_row_ptr = ptr; + ctx.current_row_is_matched = is_matched; wd.collision += collision; #undef NOT_MATCHED @@ -752,7 +778,7 @@ struct 
ProbePrefetchState JOIN_PROBE_HELPER_TEMPLATE void JoinProbeHelper::probeFillColumnsPrefetch( - JoinProbeContext & context, + JoinProbeContext & ctx, JoinProbeWorkerData & wd, MutableColumns & added_columns) { @@ -761,18 +787,18 @@ void JoinProbeHelper::probeFillColumnsPrefetch( using HashValueType = typename KeyGetter::HashValueType; using Adder = JoinProbeAdder; - auto & key_getter = *static_cast(context.key_getter.get()); - if (!context.prefetch_states) + auto & key_getter = *static_cast(ctx.key_getter.get()); + if (!ctx.prefetch_states) { - context.prefetch_states = decltype(context.prefetch_states)( + ctx.prefetch_states = decltype(ctx.prefetch_states)( static_cast(new ProbePrefetchState[settings.probe_prefetch_step]), [](void * ptr) { delete[] static_cast *>(ptr); }); } - auto * states = static_cast *>(context.prefetch_states.get()); + auto * states = static_cast *>(ctx.prefetch_states.get()); - size_t idx = context.current_row_idx; - size_t active_states = context.prefetch_active_states; - size_t k = context.prefetch_iter; + size_t idx = ctx.current_row_idx; + size_t active_states = ctx.prefetch_active_states; + size_t k = ctx.prefetch_iter; size_t current_offset = wd.result_block.rows(); size_t collision = 0; size_t key_offset = sizeof(RowPtr); @@ -781,19 +807,19 @@ void JoinProbeHelper::probeFillColumnsPrefetch( key_offset += sizeof(HashValueType); } -#define NOT_MATCHED(not_matched, idx) \ - if constexpr (Adder::need_not_matched) \ - { \ - if (not_matched) \ - { \ - bool is_end = Adder::addNotMatched(*this, wd, idx, current_offset); \ - if unlikely (is_end) \ - break; \ - } \ +#define NOT_MATCHED(not_matched, idx) \ + if constexpr (Adder::need_not_matched) \ + { \ + if (not_matched) \ + { \ + bool is_end = Adder::addNotMatched(*this, ctx, wd, idx, current_offset); \ + if unlikely (is_end) \ + break; \ + } \ } const size_t probe_prefetch_step = settings.probe_prefetch_step; - while (idx < context.rows || active_states > 0) + while (idx < ctx.rows || 
active_states > 0) { k = k == probe_prefetch_step ? 0 : k; auto * state = &states[k]; @@ -815,6 +841,7 @@ void JoinProbeHelper::probeFillColumnsPrefetch( { bool is_end = Adder::addMatched( *this, + ctx, wd, added_columns, state->index, @@ -885,14 +912,14 @@ void JoinProbeHelper::probeFillColumnsPrefetch( if constexpr (has_null_map) { bool is_end = false; - while (idx < context.rows) + while (idx < ctx.rows) { - if (!(*context.null_map)[idx]) + if (!(*ctx.null_map)[idx]) break; if constexpr (Adder::need_not_matched) { - is_end = Adder::addNotMatched(*this, wd, idx, current_offset); + is_end = Adder::addNotMatched(*this, ctx, wd, idx, current_offset); if unlikely (is_end) { ++idx; @@ -909,7 +936,7 @@ void JoinProbeHelper::probeFillColumnsPrefetch( } } - if unlikely (idx >= context.rows) + if unlikely (idx >= ctx.rows) { ++k; continue; @@ -937,16 +964,16 @@ void JoinProbeHelper::probeFillColumnsPrefetch( Adder::flush(*this, wd, added_columns); - context.current_row_idx = idx; - context.prefetch_active_states = active_states; - context.prefetch_iter = k; + ctx.current_row_idx = idx; + ctx.prefetch_active_states = active_states; + ctx.prefetch_iter = k; wd.collision += collision; #undef NOT_MATCHED } Block JoinProbeHelper::handleOtherConditions( - JoinProbeContext & context, + JoinProbeContext & ctx, JoinProbeWorkerData & wd, ASTTableJoin::Kind kind, bool late_materialization) @@ -1007,7 +1034,7 @@ Block JoinProbeHelper::handleOtherConditions( { size_t idx = wd.selective_offsets[i]; bool is_matched = wd.filter[i]; - context.rows_not_matched[idx] &= !is_matched; + ctx.rows_not_matched[idx] &= !is_matched; } } @@ -1155,7 +1182,7 @@ Block JoinProbeHelper::handleOtherConditions( } else { - auto & src_column = context.block.safeGetByPosition(i); + auto & src_column = ctx.block.safeGetByPosition(i); des_column.column->assumeMutable()->insertSelectiveRangeFrom( *src_column.column.get(), wd.result_block_filter_offsets, @@ -1166,7 +1193,7 @@ Block 
JoinProbeHelper::handleOtherConditions( } if unlikely (!result_block_filter_offsets_is_initialized) init_result_block_filter_offsets(); - auto & src_column = context.block.safeGetByPosition(i); + auto & src_column = ctx.block.safeGetByPosition(i); des_column.column->assumeMutable() ->insertSelectiveRangeFrom(*src_column.column.get(), wd.result_block_filter_offsets, start, length); } @@ -1198,48 +1225,48 @@ Block JoinProbeHelper::handleOtherConditions( return res_block; } - if (kind == LeftOuter && context.isProbeFinished()) - return fillNotMatchedRowsForLeftOuter(context, wd); + if (kind == LeftOuter && ctx.isProbeFinished()) + return fillNotMatchedRowsForLeftOuter(ctx, wd); return output_block_after_finalize; } -Block JoinProbeHelper::fillNotMatchedRowsForLeftOuter(JoinProbeContext & context, JoinProbeWorkerData & wd) +Block JoinProbeHelper::fillNotMatchedRowsForLeftOuter(JoinProbeContext & ctx, JoinProbeWorkerData & wd) { RUNTIME_CHECK(join->kind == LeftOuter); RUNTIME_CHECK(join->has_other_condition); - RUNTIME_CHECK(context.isProbeFinished()); - if (context.not_matched_offsets_idx < 0) + RUNTIME_CHECK(ctx.isProbeFinished()); + if (ctx.not_matched_offsets_idx < 0) { - size_t rows = context.rows; - size_t not_matched_result_size = countBytesInFilter(context.rows_not_matched); - auto & offsets = context.not_matched_offsets; + size_t rows = ctx.rows; + size_t not_matched_result_size = countBytesInFilter(ctx.rows_not_matched); + auto & offsets = ctx.not_matched_offsets; offsets.clear(); offsets.reserve(not_matched_result_size); if likely (rows <= BASE_OFFSETS.size()) { - filterImpl(&context.rows_not_matched[0], &context.rows_not_matched[rows], &BASE_OFFSETS[0], offsets); + filterImpl(&ctx.rows_not_matched[0], &ctx.rows_not_matched[rows], &BASE_OFFSETS[0], offsets); RUNTIME_CHECK(offsets.size() == not_matched_result_size); } else { for (size_t i = 0; i < rows; ++i) { - if (context.rows_not_matched[i]) + if (ctx.rows_not_matched[i]) offsets.push_back(i); } } - 
context.not_matched_offsets_idx = 0; + ctx.not_matched_offsets_idx = 0; } const auto & output_block_after_finalize = join->output_block_after_finalize; - if (static_cast(context.not_matched_offsets_idx) >= context.not_matched_offsets.size()) + if (static_cast(ctx.not_matched_offsets_idx) >= ctx.not_matched_offsets.size()) { // JoinProbeContext::isAllFinished checks if all not matched rows have been output // by verifying whether rows_not_matched is empty. - context.rows_not_matched.clear(); + ctx.rows_not_matched.clear(); return output_block_after_finalize; } @@ -1256,7 +1283,7 @@ Block JoinProbeHelper::fillNotMatchedRowsForLeftOuter(JoinProbeContext & context } size_t remaining_insert_size = settings.max_block_size - wd.result_block_for_other_condition.rows(); - size_t result_size = context.not_matched_offsets.size() - context.not_matched_offsets_idx; + size_t result_size = ctx.not_matched_offsets.size() - ctx.not_matched_offsets_idx; size_t length = std::min(result_size, remaining_insert_size); const auto & output_column_indexes = join->output_column_indexes; @@ -1275,20 +1302,20 @@ Block JoinProbeHelper::fillNotMatchedRowsForLeftOuter(JoinProbeContext & context if (output_index < 0) continue; auto & des_column = wd.result_block_for_other_condition.safeGetByPosition(output_index); - auto & src_column = context.block.safeGetByPosition(i); + auto & src_column = ctx.block.safeGetByPosition(i); des_column.column->assumeMutable()->insertSelectiveRangeFrom( *src_column.column.get(), - context.not_matched_offsets, - context.not_matched_offsets_idx, + ctx.not_matched_offsets, + ctx.not_matched_offsets_idx, length); } - context.not_matched_offsets_idx += length; + ctx.not_matched_offsets_idx += length; - if (static_cast(context.not_matched_offsets_idx) >= context.not_matched_offsets.size()) + if (static_cast(ctx.not_matched_offsets_idx) >= ctx.not_matched_offsets.size()) { // JoinProbeContext::isAllFinished checks if all not matched rows have been output // by verifying 
whether rows_not_matched is empty. - context.rows_not_matched.clear(); + ctx.rows_not_matched.clear(); } Block res_block; @@ -1296,11 +1323,11 @@ Block JoinProbeHelper::fillNotMatchedRowsForLeftOuter(JoinProbeContext & context return res_block; } -Block JoinProbeHelper::genResultBlockForLeftOuterSemi(JoinProbeContext & context, JoinProbeWorkerData & wd) +Block JoinProbeHelper::genResultBlockForLeftOuterSemi(JoinProbeContext & ctx) { RUNTIME_CHECK(join->kind == LeftOuterSemi || join->kind == LeftOuterAnti); RUNTIME_CHECK(!join->has_other_condition); - RUNTIME_CHECK(context.isProbeFinished()); + RUNTIME_CHECK(ctx.isProbeFinished()); Block res_block = join->output_block_after_finalize.cloneEmpty(); size_t columns = res_block.columns(); @@ -1309,14 +1336,14 @@ Block JoinProbeHelper::genResultBlockForLeftOuterSemi(JoinProbeContext & context { if (i == match_helper_column_index) continue; - res_block.getByPosition(i) = context.block.getByName(res_block.getByPosition(i).name); + res_block.getByPosition(i) = ctx.block.getByName(res_block.getByPosition(i).name); } MutableColumnPtr match_helper_column_ptr = res_block.getByPosition(match_helper_column_index).column->cloneEmpty(); auto * match_helper_column = typeid_cast(match_helper_column_ptr.get()); - match_helper_column->getNullMapColumn().getData().resize_fill_zero(context.rows); + match_helper_column->getNullMapColumn().getData().resize_fill_zero(ctx.rows); auto * match_helper_res = &typeid_cast &>(match_helper_column->getNestedColumn()).getData(); - match_helper_res->swap(wd.match_helper_res); + match_helper_res->swap(ctx.semi_match_res); res_block.getByPosition(match_helper_column_index).column = std::move(match_helper_column_ptr); diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h index 672aa384820..c24a087822e 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h @@ -45,6 +45,8 @@ struct JoinProbeContext /// < 
0 means not_matched_offsets is not initialized. ssize_t not_matched_offsets_idx = -1; IColumn::Offsets not_matched_offsets; + /// For left outer (anti) semi join. + PaddedPODArray semi_match_res; size_t prefetch_active_states = 0; size_t prefetch_iter = 0; @@ -80,8 +82,6 @@ struct alignas(CPU_CACHE_LINE_SIZE) JoinProbeWorkerData IColumn::Offsets selective_offsets; /// For left outer join without other conditions IColumn::Offsets not_matched_selective_offsets; - /// For left outer (anti) semi join without other conditions - PaddedPODArray match_helper_res; RowPtrs insert_batch; @@ -169,28 +169,27 @@ class JoinProbeHelper : public JoinProbeHelperUtil public: JoinProbeHelper(const HashJoin * join, bool late_materialization); - Block probe(JoinProbeContext & context, JoinProbeWorkerData & wd); + Block probe(JoinProbeContext & ctx, JoinProbeWorkerData & wd); private: JOIN_PROBE_HELPER_TEMPLATE - Block probeImpl(JoinProbeContext & context, JoinProbeWorkerData & wd); + Block probeImpl(JoinProbeContext & ctx, JoinProbeWorkerData & wd); JOIN_PROBE_HELPER_TEMPLATE - void NO_INLINE - probeFillColumns(JoinProbeContext & context, JoinProbeWorkerData & wd, MutableColumns & added_columns); + void NO_INLINE probeFillColumns(JoinProbeContext & ctx, JoinProbeWorkerData & wd, MutableColumns & added_columns); JOIN_PROBE_HELPER_TEMPLATE void NO_INLINE - probeFillColumnsPrefetch(JoinProbeContext & context, JoinProbeWorkerData & wd, MutableColumns & added_columns); + probeFillColumnsPrefetch(JoinProbeContext & ctx, JoinProbeWorkerData & wd, MutableColumns & added_columns); Block handleOtherConditions( - JoinProbeContext & context, + JoinProbeContext & ctx, JoinProbeWorkerData & wd, ASTTableJoin::Kind kind, bool late_materialization); - Block fillNotMatchedRowsForLeftOuter(JoinProbeContext & context, JoinProbeWorkerData & wd); + Block fillNotMatchedRowsForLeftOuter(JoinProbeContext & ctx, JoinProbeWorkerData & wd); - Block genResultBlockForLeftOuterSemi(JoinProbeContext & context, 
JoinProbeWorkerData & wd); + Block genResultBlockForLeftOuterSemi(JoinProbeContext & ctx); private: template From ff2e01e21bfa7c4d438bede64def59b68231df4b Mon Sep 17 00:00:00 2001 From: gengliqi Date: Wed, 23 Apr 2025 23:53:32 +0800 Subject: [PATCH 44/84] u Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/HashJoin.cpp | 12 +- dbms/src/Interpreters/JoinV2/HashJoin.h | 1 + .../src/Interpreters/JoinV2/HashJoinProbe.cpp | 5 +- dbms/src/Interpreters/JoinV2/HashJoinProbe.h | 4 + .../src/Interpreters/JoinV2/SemiJoinProbe.cpp | 221 ++++++++++++++++++ dbms/src/Interpreters/JoinV2/SemiJoinProbe.h | 102 +++++--- 6 files changed, 310 insertions(+), 35 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoin.cpp b/dbms/src/Interpreters/JoinV2/HashJoin.cpp index 8f9a71478b3..ad7f04691d5 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoin.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoin.cpp @@ -28,6 +28,7 @@ #include #include #include +#include "Interpreters/JoinV2/SemiJoinProbe.h" namespace DB { @@ -498,7 +499,10 @@ void HashJoin::workAfterBuildRowFinish() fiu_do_on(FailPoints::force_join_v2_probe_enable_lm, { late_materialization = true; }); fiu_do_on(FailPoints::force_join_v2_probe_disable_lm, { late_materialization = false; }); - join_probe_helper = std::make_unique(this, late_materialization); + if (SemiJoinProbeHelper::isSupported(kind, has_other_condition)) + semi_join_probe_helper = std::make_unique(this); + else + join_probe_helper = std::make_unique(this, late_materialization); LOG_INFO( log, @@ -636,7 +640,11 @@ Block HashJoin::probeBlock(JoinProbeContext & context, size_t stream_index) FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::random_join_prob_failpoint); auto & wd = probe_workers_data[stream_index]; - Block res = join_probe_helper->probe(context, wd); + Block res; + if (semi_join_probe_helper) + res = semi_join_probe_helper->probe(context, wd); + else + res = join_probe_helper->probe(context, wd); if (context.isAllFinished()) wd.probe_handle_rows += 
context.rows; return res; diff --git a/dbms/src/Interpreters/JoinV2/HashJoin.h b/dbms/src/Interpreters/JoinV2/HashJoin.h index ffdab906747..3cf5c799419 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoin.h +++ b/dbms/src/Interpreters/JoinV2/HashJoin.h @@ -153,6 +153,7 @@ class HashJoin std::vector probe_workers_data; std::atomic active_probe_worker = 0; std::unique_ptr join_probe_helper; + std::unique_ptr semi_join_probe_helper; const JoinProfileInfoPtr profile_info = std::make_shared(); diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index bd23eed45a5..2644ea51506 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -58,6 +58,7 @@ void JoinProbeContext::resetBlock(Block & block_) current_row_idx = 0; current_build_row_ptr = nullptr; current_row_is_matched = false; + semi_join_pending_probe_list.reset(); prefetch_active_states = 0; @@ -121,7 +122,7 @@ void JoinProbeContext::prepareForHashProbe( not_matched_offsets.clear(); } - if ((kind == Semi || kind == Anti || kind == LeftOuterSemi || kind == LeftOuterAnti) && has_other_condition) + if (SemiJoinProbeHelper::isSupported(kind, has_other_condition)) { if unlikely (!semi_join_pending_probe_list) { @@ -129,7 +130,7 @@ void JoinProbeContext::prepareForHashProbe( static_cast(new SemiJoinPendingProbeList), [](void * ptr) { delete static_cast(ptr); }); } - static_cast(semi_join_pending_probe_list.get())->reset(block.rows() + 1); + static_cast(semi_join_pending_probe_list.get())->reset(block.rows()); } is_prepared = true; diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h index 0e293ae98ac..84679a1df82 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h @@ -45,6 +45,10 @@ struct JoinProbeContext /// < 0 means not_matched_offsets is not initialized. 
ssize_t not_matched_offsets_idx = -1; IColumn::Offsets not_matched_offsets; + /// For left outer (anti) semi join. + PaddedPODArray semi_match_res; + /// For left outer (anti) semi join with null-eq-from-in conditions. + PaddedPODArray semi_match_null_res; /// For (left outer) (anti) semi join with other conditions. std::unique_ptr> semi_join_pending_probe_list; diff --git a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp index 62bfcff54dc..adcea60a80e 100644 --- a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp @@ -31,6 +31,35 @@ enum class SemiJoinProbeResType : UInt8 NULL_VALUE, }; +template <> +struct SemiJoinProbeAdder +{ + static constexpr bool need_matched = true; + static constexpr bool need_not_matched = false; + static constexpr bool break_on_first_match = true; + + static bool ALWAYS_INLINE addMatched( + SemiJoinProbeHelper & helper, + JoinProbeWorkerData & wd, + MutableColumns &, + size_t idx, + size_t & current_offset, + RowPtr, + size_t) + { + ++current_offset; + wd.selective_offsets.push_back(idx); + return current_offset >= helper.settings.max_block_size; + } + + static bool ALWAYS_INLINE addNotMatched(SemiJoinProbeHelper &, JoinProbeWorkerData &, size_t, size_t &) + { + return false; + } + + static void flush(SemiJoinProbeHelper &, JoinProbeWorkerData &, MutableColumns &) {} +}; + SemiJoinProbeHelper::SemiJoinProbeHelper(const HashJoin * join) : JoinProbeHelperUtil(join->settings, join->row_layout) , join(join) @@ -89,6 +118,11 @@ SemiJoinProbeHelper::SemiJoinProbeHelper(const HashJoin * join) #undef CALL2 } +bool SemiJoinProbeHelper::isSupported(ASTTableJoin::Kind kind, bool has_other_condition) +{ + return has_other_condition && (kind == Semi || kind == Anti || kind == LeftOuterSemi || kind == LeftOuterAnti); +} + Block SemiJoinProbeHelper::probe(JoinProbeContext & context, JoinProbeWorkerData & wd) { if (context.null_map) @@ -105,7 +139,194 @@ 
Block SemiJoinProbeHelper::probeImpl(JoinProbeContext & context, JoinProbeWorker auto * probe_list = static_cast(context.semi_join_pending_probe_list.get()); RUNTIME_CHECK(probe_list->size() == context.rows); + + size_t left_columns = join->left_sample_block_pruned.columns(); + size_t right_columns = join->right_sample_block_pruned.columns(); + if (!wd.result_block) + { + RUNTIME_CHECK(left_columns + right_columns == join->all_sample_block_pruned.columns()); + for (size_t i = 0; i < left_columns + right_columns; ++i) + { + ColumnWithTypeAndName new_column = join->all_sample_block_pruned.safeGetByPosition(i).cloneEmpty(); + new_column.column->assumeMutable()->reserveAlign(settings.max_block_size, FULL_VECTOR_SIZE_AVX2); + wd.result_block.insert(std::move(new_column)); + } + } + + MutableColumns added_columns(right_columns); + for (size_t i = 0; i < right_columns; ++i) + added_columns[i] = wd.result_block.safeGetByPosition(left_columns + i).column->assumeMutable(); + + Stopwatch watch; + if (pointer_table.enableProbePrefetch()) + { + probeFillColumnsPrefetch< + KeyGetter, + kind, + has_null_map, + tagged_pointer, true>(context, wd, added_columns); + + probeFillColumnsPrefetch< + KeyGetter, + kind, + has_null_map, + tagged_pointer, false>(context, wd, added_columns); + } + else + { + probeFillColumns( + context, + wd, + added_columns); + + probeFillColumns( + context, + wd, + added_columns); + } + wd.probe_hash_table_time += watch.elapsedFromLastTime(); + + // Move the mutable column pointers back into the wd.result_block, dropping the extra reference (ref_count 2→1). + // Alternative: added_columns.clear(); but that is less explicit and may misleadingly imply the columns are discarded. 
+ for (size_t i = 0; i < right_columns; ++i) + wd.result_block.safeGetByPosition(left_columns + i).column = std::move(added_columns[i]); +} + +template +void SemiJoinProbeHelper::probeFillColumns(JoinProbeContext & context, JoinProbeWorkerData & wd, MutableColumns & added_columns) +{ + using KeyGetterType = typename KeyGetter::Type; + using Hash = typename KeyGetter::Hash; + using HashValueType = typename KeyGetter::HashValueType; + using Adder = SemiJoinProbeAdder; + + auto & key_getter = *static_cast(context.key_getter.get()); + size_t current_offset = wd.result_block.rows(); + size_t idx = context.current_row_idx; + RowPtr ptr = context.current_build_row_ptr; + bool is_matched = context.current_row_is_matched; + size_t collision = 0; + size_t key_offset = sizeof(RowPtr); + if constexpr (KeyGetterType::joinKeyCompareHashFirst()) + { + key_offset += sizeof(HashValueType); + } + +#define NOT_MATCHED(not_matched) \ + if constexpr (Adder::need_not_matched) \ + { \ + assert(ptr == nullptr); \ + if (not_matched) \ + { \ + bool is_end = Adder::addNotMatched(*this, wd, idx, current_offset); \ + if unlikely (is_end) \ + { \ + ++idx; \ + break; \ + } \ + } \ + } + + for (; idx < context.rows; ++idx) + { + if constexpr (has_null_map) + { + if ((*context.null_map)[idx]) + { + NOT_MATCHED(true) + continue; + } + } + + const auto & key = key_getter.getJoinKey(idx); + auto hash = static_cast(Hash()(key)); + UInt16 hash_tag = hash & ROW_PTR_TAG_MASK; + if likely (ptr == nullptr) + { + ptr = pointer_table.getHeadPointer(hash); + if (ptr == nullptr) + { + NOT_MATCHED(true) + continue; + } + + if constexpr (tagged_pointer) + { + if (!containOtherTag(ptr, hash_tag)) + { + ptr = nullptr; + NOT_MATCHED(true) + continue; + } + ptr = removeRowPtrTag(ptr); + } + if constexpr (Adder::need_not_matched) + is_matched = false; + } + while (true) + { + const auto & key2 = key_getter.deserializeJoinKey(ptr + key_offset); + bool key_is_equal = joinKeyIsEqual(key_getter, key, key2, hash, ptr); + 
collision += !key_is_equal; + if (key_is_equal) + { + if constexpr (Adder::need_not_matched) + is_matched = true; + + if constexpr (Adder::need_matched) + { + bool is_end = Adder::addMatched( + *this, + wd, + added_columns, + idx, + current_offset, + ptr, + key_offset + key_getter.getRequiredKeyOffset(key2)); + + if unlikely (is_end) + { + if constexpr (Adder::break_on_first_match) + ptr = nullptr; + break; + } + } + + if constexpr (Adder::break_on_first_match) + { + ptr = nullptr; + break; + } + } + + ptr = getNextRowPtr(ptr); + if (ptr == nullptr) + break; + } + if unlikely (ptr != nullptr) + { + ptr = getNextRowPtr(ptr); + if (ptr == nullptr) + ++idx; + break; + } + NOT_MATCHED(!is_matched) + } + + Adder::flush(*this, wd, added_columns); + + context.current_row_idx = idx; + context.current_build_row_ptr = ptr; + context.current_row_is_matched = is_matched; + wd.collision += collision; + +#undef NOT_MATCHED } +template +void SemiJoinProbeHelper::probeFillColumnsPrefetch(JoinProbeContext & context, JoinProbeWorkerData & wd, MutableColumns & added_columns) +{ + +} } // namespace DB diff --git a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.h b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.h index d2675347718..45f7186e6cc 100644 --- a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.h +++ b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.h @@ -19,102 +19,142 @@ namespace DB { -/// A index‑based doubly‑linked circular list for managing semi join pending probe rows. -/// After reset(N), it holds N entries plus a sentinel at index 0, and supports O(1) insert/remove by index. +/// A reusable, index‑based doubly‑linked circular list for managing semi join pending probe rows. +/// Supports O(1) append/remove by index. 
class SemiJoinPendingProbeList { public: - using index_t = UInt32; + using IndexType = UInt32; struct PendingProbeRow { RowPtr build_row_ptr; + bool has_null_eq_from_in; UInt32 pace; - /// Embedded list pointers - UInt32 prev_node; - UInt32 next_node; + /// Embedded list indexes + IndexType prev_idx; + IndexType next_idx; }; class Iterator { public: - Iterator(SemiJoinPendingProbeList * list, index_t idx) + Iterator(SemiJoinPendingProbeList & list, IndexType idx) : list(list) , idx(idx) {} - PendingProbeRow & operator*() const { return list->probe_rows[idx]; } - PendingProbeRow * operator->() const { return &list->probe_rows[idx]; } + inline IndexType getIndex() const { return idx; } + + PendingProbeRow & operator*() const { return list.probe_rows[idx]; } + PendingProbeRow * operator->() const { return &list.probe_rows[idx]; } Iterator & operator++() { - idx = list->probe_rows[idx].next_node; + idx = list.probe_rows[idx].next_idx; return *this; } bool operator!=(const Iterator & other) const { return idx != other.idx; } private: - SemiJoinPendingProbeList * list; - index_t idx; + SemiJoinPendingProbeList & list; + IndexType idx; }; SemiJoinPendingProbeList() = default; + /// After reset(n), it holds n entries plus a sentinel at index n. void reset(size_t n) { RUNTIME_CHECK(n <= UINT32_MAX); - probe_rows.clear(); probe_rows.resize(n + 1); + sentinel_idx = static_cast(n); // Sentinel circular self-loop - probe_rows[0].next_node = 0; - probe_rows[0].prev_node = 0; + probe_rows[sentinel_idx].prev_idx = sentinel_idx; + probe_rows[sentinel_idx].next_idx = sentinel_idx; + +#ifndef NDEBUG + // reset should be called after all slots are removed. 
+ assert(slot_count == 0); + // Isolate all slots + for (IndexType i = 0; i < n; ++i) { + probe_rows[i].prev_idx = i; + probe_rows[i].next_idx = i; + } +#endif } inline size_t size() { return probe_rows.size() - 1; } /// Append an existing slot by index at the tail (before sentinel) - inline void append(index_t idx) + inline void append(IndexType idx) { - index_t tail = probe_rows[0].prev_node; - probe_rows[tail].next_node = idx; - probe_rows[idx].prev_node = tail; - probe_rows[idx].next_node = 0; - probe_rows[0].prev_node = idx; +#ifndef NDEBUG + assert(idx < size()); + assert(probe_rows[idx].prev_idx == idx && probe_rows[idx].next_idx == idx); + ++slot_count; +#endif + IndexType tail = probe_rows[sentinel_idx].prev_idx; + probe_rows[tail].next_idx = idx; + probe_rows[idx].prev_idx = tail; + probe_rows[idx].next_idx = sentinel_idx; + probe_rows[sentinel_idx].prev_idx = idx; } /// Remove a slot by index from the list - inline void remove(index_t idx) + inline void remove(IndexType idx) { - index_t prev = probe_rows[idx].prev_node; - index_t next = probe_rows[idx].next_node; - probe_rows[prev].next_node = next; - probe_rows[next].prev_node = prev; +#ifndef NDEBUG + assert(idx < size()); + assert(probe_rows[idx].prev_idx != idx && probe_rows[idx].next_idx != idx); + assert(slot_count > 0); + --slot_count; +#endif + IndexType prev = probe_rows[idx].prev_idx; + IndexType next = probe_rows[idx].next_idx; + probe_rows[prev].next_idx = next; + probe_rows[next].prev_idx = prev; } - Iterator begin() { return Iterator(this, probe_rows[0].next_node); } - Iterator end() { return Iterator(this, 0); } - - PendingProbeRow & operator[](index_t idx) { return probe_rows[idx]; } - const PendingProbeRow & operator[](index_t idx) const { return probe_rows[idx]; } + Iterator begin() { return Iterator(*this, probe_rows[sentinel_idx].next_idx); } + Iterator end() { return Iterator(*this, sentinel_idx); } private: PaddedPODArray probe_rows; + IndexType sentinel_idx = 0; +#ifndef NDEBUG + 
size_t slot_count = 0; +#endif }; +template +struct SemiJoinProbeAdder; + class HashJoin; class SemiJoinProbeHelper : public JoinProbeHelperUtil { public: explicit SemiJoinProbeHelper(const HashJoin * join); + static bool isSupported(ASTTableJoin::Kind kind, bool has_other_condition); + Block probe(JoinProbeContext & context, JoinProbeWorkerData & wd); private: template Block probeImpl(JoinProbeContext & context, JoinProbeWorkerData & wd); + template + void NO_INLINE probeFillColumns(JoinProbeContext & context, JoinProbeWorkerData & wd, MutableColumns & added_columns); + + template + void NO_INLINE probeFillColumnsPrefetch(JoinProbeContext & context, JoinProbeWorkerData & wd, MutableColumns & added_columns); + private: + template + friend struct SemiJoinProbeAdder; + using FuncType = Block (SemiJoinProbeHelper::*)(JoinProbeContext &, JoinProbeWorkerData &); FuncType func_ptr_has_null = nullptr; FuncType func_ptr_no_null = nullptr; From 8f681c1a06c3068629c7b7f73504f039fd179397 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Thu, 24 Apr 2025 00:30:31 +0800 Subject: [PATCH 45/84] u Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index d8de077f7d2..5c4af9915bc 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -507,6 +507,11 @@ Block JoinProbeHelper::probeImpl(JoinProbeContext & ctx, JoinProbeWorkerData & w if (ctx.isProbeFinished()) return fillNotMatchedRowsForLeftOuter(ctx, wd); } + if constexpr (kind == LeftOuterSemi || kind == LeftOuterAnti) + { + // Sanity check + RUNTIME_CHECK(ctx.semi_match_res.size() == ctx.rows); + } wd.insert_batch.clear(); wd.insert_batch.reserve(settings.probe_insert_batch_size); From 97b2b703146aa0fa943c098e1cbede2f3c3ad414 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Thu, 24 Apr 2025 01:51:38 +0800 
Subject: [PATCH 46/84] fix Signed-off-by: gengliqi --- .../src/Interpreters/JoinV2/HashJoinProbe.cpp | 27 ++++++++++++------- dbms/src/Interpreters/JoinV2/HashJoinProbe.h | 2 +- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index 5c4af9915bc..5c83f34e37c 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -120,9 +120,7 @@ void JoinProbeContext::prepareForHashProbe( not_matched_offsets.clear(); } if (kind == LeftOuterSemi || kind == LeftOuterAnti) - { - semi_match_res.resize(rows); - } + left_semi_match_res.resize(rows); is_prepared = true; } @@ -371,13 +369,14 @@ struct JoinProbeAdder RowPtr, size_t) { - ctx.semi_match_res[idx] = 1; + ctx.left_semi_match_res[idx] = 1; return false; } static bool ALWAYS_INLINE - addNotMatched(JoinProbeHelper &, JoinProbeContext &, JoinProbeWorkerData &, size_t, size_t &) + addNotMatched(JoinProbeHelper &, JoinProbeContext & ctx, JoinProbeWorkerData &, size_t idx, size_t &) { + ctx.left_semi_match_res[idx] = 0; return false; } @@ -391,16 +390,24 @@ struct JoinProbeAdder static constexpr bool need_not_matched = true; static constexpr bool break_on_first_match = true; - static bool ALWAYS_INLINE - addMatched(JoinProbeHelper &, JoinProbeWorkerData &, MutableColumns &, size_t, size_t &, RowPtr, size_t) + static bool ALWAYS_INLINE addMatched( + JoinProbeHelper &, + JoinProbeContext & ctx, + JoinProbeWorkerData &, + MutableColumns &, + size_t idx, + size_t &, + RowPtr, + size_t) { + ctx.left_semi_match_res[idx] = 0; return false; } static bool ALWAYS_INLINE addNotMatched(JoinProbeHelper &, JoinProbeContext & ctx, JoinProbeWorkerData &, size_t idx, size_t &) { - ctx.semi_match_res[idx] = 1; + ctx.left_semi_match_res[idx] = 1; return false; } @@ -510,7 +517,7 @@ Block JoinProbeHelper::probeImpl(JoinProbeContext & ctx, JoinProbeWorkerData & w if constexpr (kind == 
LeftOuterSemi || kind == LeftOuterAnti) { // Sanity check - RUNTIME_CHECK(ctx.semi_match_res.size() == ctx.rows); + RUNTIME_CHECK(ctx.left_semi_match_res.size() == ctx.rows); } wd.insert_batch.clear(); @@ -1348,7 +1355,7 @@ Block JoinProbeHelper::genResultBlockForLeftOuterSemi(JoinProbeContext & ctx) auto * match_helper_column = typeid_cast(match_helper_column_ptr.get()); match_helper_column->getNullMapColumn().getData().resize_fill_zero(ctx.rows); auto * match_helper_res = &typeid_cast &>(match_helper_column->getNestedColumn()).getData(); - match_helper_res->swap(ctx.semi_match_res); + match_helper_res->swap(ctx.left_semi_match_res); res_block.getByPosition(match_helper_column_index).column = std::move(match_helper_column_ptr); diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h index c24a087822e..97502648e2c 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h @@ -46,7 +46,7 @@ struct JoinProbeContext ssize_t not_matched_offsets_idx = -1; IColumn::Offsets not_matched_offsets; /// For left outer (anti) semi join. 
- PaddedPODArray semi_match_res; + PaddedPODArray left_semi_match_res; size_t prefetch_active_states = 0; size_t prefetch_iter = 0; From e1c050f81fbb14dae987493e1cb4332fcd1de6de Mon Sep 17 00:00:00 2001 From: gengliqi Date: Thu, 24 Apr 2025 01:52:19 +0800 Subject: [PATCH 47/84] u Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/HashJoin.cpp | 1 + dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp | 18 +++++++++++++----- dbms/src/Interpreters/JoinV2/HashJoinProbe.h | 9 ++++++--- dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp | 14 ++++++++++++++ dbms/src/Interpreters/JoinV2/SemiJoinProbe.h | 12 ++++++------ 5 files changed, 40 insertions(+), 14 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoin.cpp b/dbms/src/Interpreters/JoinV2/HashJoin.cpp index 5dd5f062936..305be3a2c99 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoin.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoin.cpp @@ -629,6 +629,7 @@ Block HashJoin::probeBlock(JoinProbeContext & ctx, size_t stream_index) method, kind, has_other_condition, + !non_equal_conditions.other_eq_cond_from_in_name.empty(), key_names_left, non_equal_conditions.left_filter_column, probe_output_name_set, diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index d0b8739df92..c9e9a919395 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -73,6 +73,7 @@ void JoinProbeContext::prepareForHashProbe( HashJoinKeyMethod method, ASTTableJoin::Kind kind, bool has_other_condition, + bool has_other_eq_from_in_condition, const Names & key_names, const String & filter_column, const NameSet & probe_output_name_set, @@ -123,7 +124,14 @@ void JoinProbeContext::prepareForHashProbe( } if (kind == LeftOuterSemi || kind == LeftOuterAnti) { - semi_match_res.resize(rows); + left_semi_match_res.resize(rows); + if (has_other_eq_from_in_condition) + left_semi_match_null_res.resize(rows); + } + if ((kind == Semi || kind == 
Anti) && has_other_condition) + { + semi_selective_offsets.clear(); + semi_selective_offsets.reserve(rows); } if (SemiJoinProbeHelper::isSupported(kind, has_other_condition)) @@ -384,7 +392,7 @@ struct JoinProbeAdder RowPtr, size_t) { - ctx.semi_match_res[idx] = 1; + ctx.left_semi_match_res[idx] = 1; return false; } @@ -413,7 +421,7 @@ struct JoinProbeAdder static bool ALWAYS_INLINE addNotMatched(JoinProbeHelper &, JoinProbeContext & ctx, JoinProbeWorkerData &, size_t idx, size_t &) { - ctx.semi_match_res[idx] = 1; + ctx.left_semi_match_res[idx] = 1; return false; } @@ -523,7 +531,7 @@ Block JoinProbeHelper::probeImpl(JoinProbeContext & ctx, JoinProbeWorkerData & w if constexpr (kind == LeftOuterSemi || kind == LeftOuterAnti) { // Sanity check - RUNTIME_CHECK(ctx.semi_match_res.size() == ctx.rows); + RUNTIME_CHECK(ctx.left_semi_match_res.size() == ctx.rows); } wd.insert_batch.clear(); @@ -1361,7 +1369,7 @@ Block JoinProbeHelper::genResultBlockForLeftOuterSemi(JoinProbeContext & ctx) auto * match_helper_column = typeid_cast(match_helper_column_ptr.get()); match_helper_column->getNullMapColumn().getData().resize_fill_zero(ctx.rows); auto * match_helper_res = &typeid_cast &>(match_helper_column->getNestedColumn()).getData(); - match_helper_res->swap(ctx.semi_match_res); + match_helper_res->swap(ctx.left_semi_match_res); res_block.getByPosition(match_helper_column_index).column = std::move(match_helper_column_ptr); diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h index 38bcd5f38b5..b61760f8f50 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h @@ -46,9 +46,11 @@ struct JoinProbeContext ssize_t not_matched_offsets_idx = -1; IColumn::Offsets not_matched_offsets; /// For left outer (anti) semi join. - PaddedPODArray semi_match_res; - /// For left outer (anti) semi join with null-eq-from-in conditions. 
- PaddedPODArray semi_match_null_res; + PaddedPODArray left_semi_match_res; + /// For left outer (anti) semi join with other-eq-from-in conditions. + PaddedPODArray left_semi_match_null_res; + /// For (anti) semi join with other conditions. + IColumn::Offsets semi_selective_offsets; /// For (left outer) (anti) semi join with other conditions. std::unique_ptr> semi_join_pending_probe_list; @@ -73,6 +75,7 @@ struct JoinProbeContext HashJoinKeyMethod method, ASTTableJoin::Kind kind, bool has_other_condition, + bool has_other_eq_from_in_condition, const Names & key_names, const String & filter_column, const NameSet & probe_output_name_set, diff --git a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp index b309d4744d6..5af6b0edcdb 100644 --- a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp @@ -140,6 +140,14 @@ Block SemiJoinProbeHelper::probeImpl(JoinProbeContext & ctx, JoinProbeWorkerData auto * probe_list = static_cast(ctx.semi_join_pending_probe_list.get()); RUNTIME_CHECK(probe_list->slotSize() == ctx.rows); + if constexpr (kind == LeftOuterSemi || kind == LeftOuterAnti) + { + // Sanity check + RUNTIME_CHECK(ctx.left_semi_match_res.size() == ctx.rows); + if (!join->non_equal_conditions.other_eq_cond_from_in_name.empty()) + RUNTIME_CHECK(ctx.left_semi_match_null_res.size() == ctx.rows); + } + size_t left_columns = join->left_sample_block_pruned.columns(); size_t right_columns = join->right_sample_block_pruned.columns(); if (!wd.result_block) @@ -176,6 +184,12 @@ Block SemiJoinProbeHelper::probeImpl(JoinProbeContext & ctx, JoinProbeWorkerData // Alternative: added_columns.clear(); but that is less explicit and may misleadingly imply the columns are discarded. 
for (size_t i = 0; i < right_columns; ++i) wd.result_block.safeGetByPosition(left_columns + i).column = std::move(added_columns[i]); + + if (ctx.isProbeFinished()) + { + + } + return join->output_block_after_finalize; } template diff --git a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.h b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.h index 10e8b7b21af..e91054b2bc5 100644 --- a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.h +++ b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.h @@ -70,13 +70,13 @@ class SemiJoinPendingProbeList RUNTIME_CHECK(n <= UINT32_MAX); probe_rows.resize(n + 1); sentinel_idx = static_cast(n); - // Sentinel circular self-loop + // Initialize sentinel self-loop probe_rows[sentinel_idx].prev_idx = sentinel_idx; probe_rows[sentinel_idx].next_idx = sentinel_idx; #ifndef NDEBUG // reset should be called after all slots are removed - assert(slot_count == 0); + assert(linked_count == 0); // Isolate all slots for (IndexType i = 0; i < n; ++i) { @@ -95,7 +95,7 @@ class SemiJoinPendingProbeList #ifndef NDEBUG assert(idx < slotSize()); assert(probe_rows[idx].prev_idx == idx && probe_rows[idx].next_idx == idx); - ++slot_count; + ++linked_count; #endif IndexType tail = probe_rows[sentinel_idx].prev_idx; probe_rows[tail].next_idx = idx; @@ -110,8 +110,8 @@ class SemiJoinPendingProbeList #ifndef NDEBUG assert(idx < slotSize()); assert(probe_rows[idx].prev_idx != idx && probe_rows[idx].next_idx != idx); - assert(slot_count > 0); - --slot_count; + assert(linked_count > 0); + --linked_count; #endif IndexType prev = probe_rows[idx].prev_idx; IndexType next = probe_rows[idx].next_idx; @@ -126,7 +126,7 @@ class SemiJoinPendingProbeList PaddedPODArray probe_rows; IndexType sentinel_idx = 0; #ifndef NDEBUG - size_t slot_count = 0; + size_t linked_count = 0; #endif }; From 63688f91adaca818df29dcff13fda446e6c268fe Mon Sep 17 00:00:00 2001 From: gengliqi Date: Thu, 24 Apr 2025 09:44:37 +0800 Subject: [PATCH 48/84] u Signed-off-by: gengliqi --- 
dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index 5c83f34e37c..8ea72d15137 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -120,7 +120,7 @@ void JoinProbeContext::prepareForHashProbe( not_matched_offsets.clear(); } if (kind == LeftOuterSemi || kind == LeftOuterAnti) - left_semi_match_res.resize(rows); + left_semi_match_res.resize_fill_zero(rows); is_prepared = true; } @@ -374,9 +374,8 @@ struct JoinProbeAdder } static bool ALWAYS_INLINE - addNotMatched(JoinProbeHelper &, JoinProbeContext & ctx, JoinProbeWorkerData &, size_t idx, size_t &) + addNotMatched(JoinProbeHelper &, JoinProbeContext &, JoinProbeWorkerData &, size_t, size_t &) { - ctx.left_semi_match_res[idx] = 0; return false; } @@ -405,9 +404,8 @@ struct JoinProbeAdder } static bool ALWAYS_INLINE - addNotMatched(JoinProbeHelper &, JoinProbeContext & ctx, JoinProbeWorkerData &, size_t idx, size_t &) + addNotMatched(JoinProbeHelper &, JoinProbeContext &, JoinProbeWorkerData &, size_t, size_t &) { - ctx.left_semi_match_res[idx] = 1; return false; } From 9b1421c1f867181fb02d65921780a0195ffca6a6 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Thu, 24 Apr 2025 14:49:02 +0800 Subject: [PATCH 49/84] fix Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index 8ea72d15137..c3cf03e224f 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -120,7 +120,10 @@ void JoinProbeContext::prepareForHashProbe( not_matched_offsets.clear(); } if (kind == LeftOuterSemi || kind == LeftOuterAnti) + { + left_semi_match_res.clear(); 
left_semi_match_res.resize_fill_zero(rows); + } is_prepared = true; } From 44c0ce0c0612fce4d920330c10ea179544d8be32 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Thu, 24 Apr 2025 16:04:14 +0800 Subject: [PATCH 50/84] u Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index c3cf03e224f..3b93c3857d0 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -394,21 +394,21 @@ struct JoinProbeAdder static bool ALWAYS_INLINE addMatched( JoinProbeHelper &, - JoinProbeContext & ctx, + JoinProbeContext &, JoinProbeWorkerData &, MutableColumns &, - size_t idx, + size_t, size_t &, RowPtr, size_t) { - ctx.left_semi_match_res[idx] = 0; return false; } static bool ALWAYS_INLINE - addNotMatched(JoinProbeHelper &, JoinProbeContext &, JoinProbeWorkerData &, size_t, size_t &) + addNotMatched(JoinProbeHelper &, JoinProbeContext & ctx, JoinProbeWorkerData &, size_t idx, size_t &) { + ctx.left_semi_match_res[idx] = 1; return false; } From fbb4e3f28d374d87e72431158fbab9ac0764e268 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Fri, 25 Apr 2025 01:44:50 +0800 Subject: [PATCH 51/84] u Signed-off-by: gengliqi --- .../src/Interpreters/JoinV2/HashJoinProbe.cpp | 13 +- dbms/src/Interpreters/JoinV2/HashJoinProbe.h | 3 +- .../src/Interpreters/JoinV2/SemiJoinProbe.cpp | 249 +++++++++++------- dbms/src/Interpreters/JoinV2/SemiJoinProbe.h | 145 ++-------- .../Interpreters/JoinV2/SemiJoinProbeList.h | 155 +++++++++++ 5 files changed, 334 insertions(+), 231 deletions(-) create mode 100644 dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index 06248539221..5c2431afa6c 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ 
b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -40,7 +40,9 @@ bool JoinProbeContext::isProbeFinished() const { return current_row_idx >= rows // For prefetching - && prefetch_active_states == 0; + && prefetch_active_states == 0 + // For (left outer) (anti) semi join with other conditions + && (semi_join_probe_list == nullptr || semi_join_probe_list->activeSlots() == 0); } bool JoinProbeContext::isAllFinished() const @@ -58,7 +60,6 @@ void JoinProbeContext::resetBlock(Block & block_) current_row_idx = 0; current_build_row_ptr = nullptr; current_row_is_matched = false; - semi_join_pending_probe_list.reset(); prefetch_active_states = 0; @@ -140,13 +141,7 @@ void JoinProbeContext::prepareForHashProbe( if (SemiJoinProbeHelper::isSupported(kind, has_other_condition)) { - if unlikely (!semi_join_pending_probe_list) - { - semi_join_pending_probe_list = decltype(semi_join_pending_probe_list)( - static_cast(new SemiJoinPendingProbeList), - [](void * ptr) { delete static_cast(ptr); }); - } - static_cast(semi_join_pending_probe_list.get())->reset(block.rows()); + //semi_join_probe_list->reset(rows); } is_prepared = true; diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h index b61760f8f50..3f4846702b7 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -52,7 +53,7 @@ struct JoinProbeContext /// For (anti) semi join with other conditions. IColumn::Offsets semi_selective_offsets; /// For (left outer) (anti) semi join with other conditions. 
- std::unique_ptr> semi_join_pending_probe_list; + std::unique_ptr semi_join_probe_list; size_t prefetch_active_states = 0; size_t prefetch_iter = 0; diff --git a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp index ceb27fd9b2e..2baeb05793b 100644 --- a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp @@ -14,6 +14,8 @@ #include #include + +#include "Interpreters/JoinV2/SemiJoinProbeList.h" #include "Parsers/ASTTablesInSelectQuery.h" #ifdef TIFLASH_ENABLE_AVX_SUPPORT @@ -25,77 +27,58 @@ namespace DB using enum ASTTableJoin::Kind; -enum class SemiJoinProbeResType : UInt8 +namespace +{ +template +void ALWAYS_INLINE setMatched(JoinProbeContext & ctx, size_t idx) +{ + static_assert(kind == Semi || kind == Anti || kind == LeftOuterSemi || kind == LeftOuterAnti); + if constexpr (kind == Semi) + { + ctx.semi_selective_offsets.push_back(idx); + } + else if constexpr (kind == LeftOuterSemi) + { + ctx.left_semi_match_res[idx] = 1; + } +} + +template +void ALWAYS_INLINE setNotMatched(JoinProbeContext & ctx, size_t idx, bool has_null_eq_from_in = false) { - FALSE_VALUE, - TRUE_VALUE, - NULL_VALUE, -}; - -namespace { - template - void ALWAYS_INLINE setMatched(JoinProbeContext & ctx, size_t idx) + static_assert(kind == Semi || kind == Anti || kind == LeftOuterSemi || kind == LeftOuterAnti); + if constexpr (check_other_eq_from_in_cond) { - static_assert(kind == Semi || kind == Anti || kind == LeftOuterSemi || kind == LeftOuterAnti); - if constexpr (kind == Semi) + if constexpr (kind == Anti) { - ctx.semi_selective_offsets.push_back(idx); + if (!has_null_eq_from_in) + ctx.semi_selective_offsets.push_back(idx); } else if constexpr (kind == LeftOuterSemi) + { + ctx.left_semi_match_null_res[idx] = has_null_eq_from_in; + } + else if constexpr (kind == LeftOuterAnti) { ctx.left_semi_match_res[idx] = 1; + ctx.left_semi_match_null_res[idx] = has_null_eq_from_in; } } - - template - void 
ALWAYS_INLINE setNotMatched(JoinProbeContext & ctx, size_t idx, bool has_null_eq_from_in = false) + else { - static_assert(kind == Semi || kind == Anti || kind == LeftOuterSemi || kind == LeftOuterAnti); - if constexpr (check_other_eq_from_in_cond) + if constexpr (kind == Anti) { - if constexpr (kind == Anti) - { - if (!has_null_eq_from_in) - ctx.semi_selective_offsets.push_back(idx); - } - else if constexpr (kind == LeftOuterSemi) - { - ctx.left_semi_match_null_res[idx] = has_null_eq_from_in; - } - else if constexpr (kind == LeftOuterAnti) - { - ctx.left_semi_match_res[idx] = 1; - ctx.left_semi_match_null_res[idx] = has_null_eq_from_in; - } + ctx.semi_selective_offsets.push_back(idx); } - else + else if constexpr (kind == LeftOuterAnti) { - if constexpr (kind == Anti) - { - ctx.semi_selective_offsets.push_back(idx); - } - else if constexpr (kind == LeftOuterAnti) - { - ctx.left_semi_match_res[idx] = 1; - } + ctx.left_semi_match_res[idx] = 1; } } +} } // namespace -template -struct SemiJoinProbeAdder -{ - static void ALWAYS_INLINE setMatched( - JoinProbeContext & ctx, - size_t idx) - { - ctx.semi_selective_offsets.push_back(idx); - } - - static void ALWAYS_INLINE addNotMatched(JoinProbeContext &, size_t) {} -}; - SemiJoinProbeHelper::SemiJoinProbeHelper(const HashJoin * join) : JoinProbeHelperUtil(join->settings, join->row_layout) , join(join) @@ -104,26 +87,28 @@ SemiJoinProbeHelper::SemiJoinProbeHelper(const HashJoin * join) // SemiJoinProbeHelper only handles semi join with other conditions RUNTIME_CHECK(join->has_other_condition); -#define CALL3(KeyGetter, JoinType, has_other_eq_from_in_cond, tagged_pointer) \ - { \ - func_ptr_has_null = &SemiJoinProbeHelper::probeImpl; \ - func_ptr_no_null = &SemiJoinProbeHelper::probeImpl; \ +#define CALL3(KeyGetter, JoinType, has_other_eq_from_in_cond, tagged_pointer) \ + { \ + func_ptr_has_null \ + = &SemiJoinProbeHelper::probeImpl; \ + func_ptr_no_null \ + = &SemiJoinProbeHelper::probeImpl; \ } -#define CALL2(KeyGetter, 
JoinType, has_other_eq_from_in_cond) \ - { \ - if (pointer_table.enableTaggedPointer()) \ - CALL3(KeyGetter, JoinType, has_other_eq_from_in_cond, true) \ - else \ - CALL3(KeyGetter, JoinType, has_other_eq_from_in_cond, false) \ +#define CALL2(KeyGetter, JoinType, has_other_eq_from_in_cond) \ + { \ + if (pointer_table.enableTaggedPointer()) \ + CALL3(KeyGetter, JoinType, has_other_eq_from_in_cond, true) \ + else \ + CALL3(KeyGetter, JoinType, has_other_eq_from_in_cond, false) \ } -#define CALL1(KeyGetter, JoinType) \ - { \ +#define CALL1(KeyGetter, JoinType) \ + { \ if (join->non_equal_conditions.other_eq_cond_from_in_name.empty()) \ - CALL2(KeyGetter, JoinType, false) \ - else \ - CALL2(KeyGetter, JoinType, true) \ + CALL2(KeyGetter, JoinType, false) \ + else \ + CALL2(KeyGetter, JoinType, true) \ } #define CALL(KeyGetter) \ @@ -210,15 +195,22 @@ Block SemiJoinProbeHelper::probeImpl(JoinProbeContext & ctx, JoinProbeWorkerData Stopwatch watch; if (pointer_table.enableProbePrefetch()) { - probeFillColumnsPrefetch(ctx, wd, added_columns); - - probeFillColumnsPrefetch(ctx, wd, added_columns); + //probeFillColumnsPrefetch( + // ctx, + // wd, + // added_columns); } else { - probeFillColumns(ctx, wd, added_columns); - - probeFillColumns(ctx, wd, added_columns); + probeFillColumnsFromList( + ctx, + wd, + added_columns); + if (wd.result_block.rows() < settings.max_block_size) + probeFillColumns( + ctx, + wd, + added_columns); } wd.probe_hash_table_time += watch.elapsedFromLastTime(); @@ -227,13 +219,13 @@ Block SemiJoinProbeHelper::probeImpl(JoinProbeContext & ctx, JoinProbeWorkerData for (size_t i = 0; i < right_columns; ++i) wd.result_block.safeGetByPosition(left_columns + i).column = std::move(added_columns[i]); - if (ctx.isProbeFinished()) - { - - } + if (ctx.isProbeFinished()) {} return join->output_block_after_finalize; } +static constexpr UInt16 INITIAL_PACE = 4; +static constexpr UInt16 MAX_PACE = 8192; + SEMI_JOIN_PROBE_HELPER_TEMPLATE void 
SemiJoinProbeHelper::probeFillColumns( JoinProbeContext & ctx, @@ -243,14 +235,12 @@ void SemiJoinProbeHelper::probeFillColumns( using KeyGetterType = typename KeyGetter::Type; using Hash = typename KeyGetter::Hash; using HashValueType = typename KeyGetter::HashValueType; - using Adder = SemiJoinProbeAdder; auto & key_getter = *static_cast(ctx.key_getter.get()); - auto * probe_list = static_cast(ctx.semi_join_pending_probe_list.get()); - RUNTIME_CHECK(probe_list->slotSize() == ctx.rows); + auto * probe_list = static_cast *>(ctx.semi_join_probe_list.get()); + RUNTIME_CHECK(probe_list->slotCapacity() == ctx.rows); size_t current_offset = wd.result_block.rows(); size_t idx = ctx.current_row_idx; - bool is_matched = ctx.current_row_is_matched; size_t collision = 0; size_t key_offset = sizeof(RowPtr); if constexpr (KeyGetterType::joinKeyCompareHashFirst()) @@ -269,7 +259,7 @@ void SemiJoinProbeHelper::probeFillColumns( } } - const auto & key = key_getter.getJoinKey(idx); + const auto & key = key_getter.getJoinKeyWithBuffer(idx); auto hash = static_cast(Hash()(key)); UInt16 hash_tag = hash & ROW_PTR_TAG_MASK; RowPtr ptr = pointer_table.getHeadPointer(hash); @@ -290,11 +280,8 @@ void SemiJoinProbeHelper::probeFillColumns( ptr = removeRowPtrTag(ptr); } - probe_list->append(idx); - auto & pending_probe_row = probe_list->at(idx); - pending_probe_row.has_null_eq_from_in = false; - pending_probe_row.pace = 4; - size_t end_offset = std::min(settings.max_block_size, current_offset + pending_probe_row.pace); + size_t end_offset = std::min(settings.max_block_size, current_offset + INITIAL_PACE); + size_t prev_offset = current_offset; while (true) { const auto & key2 = key_getter.deserializeJoinKey(ptr + key_offset); @@ -316,7 +303,18 @@ void SemiJoinProbeHelper::probeFillColumns( if (ptr == nullptr) break; } - pending_probe_row.build_row_ptr = ptr; + if (prev_offset == current_offset) + { + setNotMatched(ctx, idx); + continue; + } + probe_list->append(idx); + auto & probe_row = 
probe_list->at(idx); + probe_row.build_row_ptr = ptr; + probe_row.has_null_eq_from_in = false; + probe_row.pace = INITIAL_PACE * 2; + probe_row.hash = hash; + probe_row.key = key; if unlikely (current_offset >= settings.max_block_size) { if (ptr == nullptr) @@ -326,15 +324,82 @@ void SemiJoinProbeHelper::probeFillColumns( } ctx.current_row_idx = idx; - ctx.current_row_is_matched = is_matched; wd.collision += collision; } SEMI_JOIN_PROBE_HELPER_TEMPLATE -void SemiJoinProbeHelper::probeFillColumnsPrefetch( +void NO_INLINE SemiJoinProbeHelper::probeFillColumnsFromList( JoinProbeContext & ctx, JoinProbeWorkerData & wd, MutableColumns & added_columns) -{} +{ + using KeyGetterType = typename KeyGetter::Type; + using HashValueType = typename KeyGetter::HashValueType; + + auto & key_getter = *static_cast(ctx.key_getter.get()); + auto * probe_list = static_cast *>(ctx.semi_join_probe_list.get()); + RUNTIME_CHECK(probe_list->slotCapacity() == ctx.rows); + size_t current_offset = wd.result_block.rows(); + size_t collision = 0; + size_t key_offset = sizeof(RowPtr); + if constexpr (KeyGetterType::joinKeyCompareHashFirst()) + { + key_offset += sizeof(HashValueType); + } + auto iter_end = probe_list->end(); + for (auto iter = probe_list->begin(); iter != iter_end;) + { + auto & probe_row = *iter; + RowPtr ptr = probe_row.build_row_ptr; + auto idx = iter.getIndex(); + if (ptr == nullptr) + { + setNotMatched(ctx, idx, probe_row.has_null_eq_from_in); + ++iter; + probe_list->remove(idx); + continue; + } + size_t end_offset = std::min(settings.max_block_size, current_offset + probe_row.pace); + size_t prev_offset = current_offset; + while (true) + { + const auto & key2 = key_getter.deserializeJoinKey(ptr + key_offset); + bool key_is_equal = joinKeyIsEqual(key_getter, probe_row.key, key2, probe_row.hash, ptr); + collision += !key_is_equal; + if (key_is_equal) + { + ++current_offset; + wd.selective_offsets.push_back(idx); + insertRowToBatch(wd, added_columns, ptr + key_offset + 
key_getter.getRequiredKeyOffset(key2)); + if unlikely (current_offset >= end_offset) + { + ptr = getNextRowPtr(ptr); + break; + } + } + + ptr = getNextRowPtr(ptr); + if (ptr == nullptr) + break; + } + if (prev_offset == current_offset) + { + setNotMatched(ctx, idx, probe_row.has_null_eq_from_in); + auto idx = iter.getIndex(); + ++iter; + probe_list->remove(idx); + continue; + } + probe_row.build_row_ptr = ptr; + if (current_offset - prev_offset >= probe_row.pace) + probe_row.pace = std::min(MAX_PACE, probe_row.pace * 2U); + if unlikely (current_offset >= settings.max_block_size) + break; + + ++iter; + } + + wd.collision += collision; +} } // namespace DB diff --git a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.h b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.h index ef17d0377c0..ab5291505eb 100644 --- a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.h +++ b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.h @@ -1,4 +1,4 @@ -// Copyright 2024 PingCAP, Inc. +// Copyright 2025 PingCAP, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,133 +15,18 @@ #pragma once #include +#include namespace DB { -/// A reusable, index‑based doubly‑linked circular list for managing semi join pending probe rows. -/// Supports O(1) append/remove by index. 
-class SemiJoinPendingProbeList -{ -public: - using IndexType = UInt32; - - struct PendingProbeRow - { - RowPtr build_row_ptr; - bool has_null_eq_from_in; - UInt32 pace; - }; - - class Iterator - { - public: - Iterator(SemiJoinPendingProbeList & list, IndexType idx) - : list(list) - , idx(idx) - {} - - inline IndexType getIndex() const { return idx; } - - PendingProbeRow & operator*() const { return list.probe_rows[idx]; } - PendingProbeRow * operator->() const { return &list.probe_rows[idx]; } - - Iterator & operator++() - { - idx = list.probe_rows[idx].next_idx; - return *this; - } - - bool operator!=(const Iterator & other) const { return idx != other.idx; } - - private: - SemiJoinPendingProbeList & list; - IndexType idx; - }; - - SemiJoinPendingProbeList() = default; - - /// After reset(n), it holds n entries plus a sentinel at index n. - void reset(size_t n) - { - RUNTIME_CHECK(n <= UINT32_MAX); - probe_rows.resize(n + 1); - sentinel_idx = static_cast(n); - // Initialize sentinel self-loop - probe_rows[sentinel_idx].prev_idx = sentinel_idx; - probe_rows[sentinel_idx].next_idx = sentinel_idx; - -#ifndef NDEBUG - // reset should be called after all slots are removed - assert(linked_count == 0); - // Isolate all slots - for (IndexType i = 0; i < n; ++i) - { - probe_rows[i].prev_idx = i; - probe_rows[i].next_idx = i; - } -#endif - } - - /// Returns the number of usable slots in the list (excluding the sentinel). - inline size_t slotSize() const { return probe_rows.size() - 1; } - - /// Append an existing slot by index at the tail (before sentinel). 
- inline void append(IndexType idx) - { -#ifndef NDEBUG - assert(idx < slotSize()); - assert(probe_rows[idx].prev_idx == idx && probe_rows[idx].next_idx == idx); - ++linked_count; -#endif - IndexType tail = probe_rows[sentinel_idx].prev_idx; - probe_rows[tail].next_idx = idx; - probe_rows[idx].prev_idx = tail; - probe_rows[idx].next_idx = sentinel_idx; - probe_rows[sentinel_idx].prev_idx = idx; - } - - /// Remove a slot by index from the list. - inline void remove(IndexType idx) - { -#ifndef NDEBUG - assert(idx < slotSize()); - assert(probe_rows[idx].prev_idx != idx && probe_rows[idx].next_idx != idx); - assert(linked_count > 0); - --linked_count; -#endif - IndexType prev = probe_rows[idx].prev_idx; - IndexType next = probe_rows[idx].next_idx; - probe_rows[prev].next_idx = next; - probe_rows[next].prev_idx = prev; - } - - Iterator begin() { return Iterator(*this, probe_rows[sentinel_idx].next_idx); } - Iterator end() { return Iterator(*this, sentinel_idx); } - - PendingProbeRow & at(IndexType idx) { assert(idx < slotSize()); return probe_rows[idx]; } - const PendingProbeRow & at(IndexType idx) const { assert(idx < slotSize()); return probe_rows[idx]; } - -private: - struct WrapPendingProbeRow : PendingProbeRow - { - /// Embedded list indexes - IndexType prev_idx; - IndexType next_idx; - }; - - PaddedPODArray probe_rows; - IndexType sentinel_idx = 0; -#ifndef NDEBUG - size_t linked_count = 0; -#endif -}; - -template -struct SemiJoinProbeAdder; - #define SEMI_JOIN_PROBE_HELPER_TEMPLATE \ -template + template < \ + typename KeyGetter, \ + ASTTableJoin::Kind kind, \ + bool has_null_map, \ + bool has_other_eq_from_in_cond, \ + bool tagged_pointer> class HashJoin; class SemiJoinProbeHelper : public JoinProbeHelperUtil @@ -159,15 +44,18 @@ class SemiJoinProbeHelper : public JoinProbeHelperUtil SEMI_JOIN_PROBE_HELPER_TEMPLATE void NO_INLINE probeFillColumns(JoinProbeContext & ctx, JoinProbeWorkerData & wd, MutableColumns & added_columns); - SEMI_JOIN_PROBE_HELPER_TEMPLATE 
void NO_INLINE - probeFillColumnsPrefetch(JoinProbeContext & ctx, JoinProbeWorkerData & wd, MutableColumns & added_columns); + probeFillColumnsFromList(JoinProbeContext & ctx, JoinProbeWorkerData & wd, MutableColumns & added_columns); -private: - template - friend struct SemiJoinProbeAdder; + //SEMI_JOIN_PROBE_HELPER_TEMPLATE + //void NO_INLINE + //probeFillColumnsPrefetch(JoinProbeContext & ctx, JoinProbeWorkerData & wd, MutableColumns & added_columns); + //SEMI_JOIN_PROBE_HELPER_TEMPLATE + //void NO_INLINE + //probeFillColumnsPrefetchFromList(JoinProbeContext & ctx, JoinProbeWorkerData & wd, MutableColumns & added_columns); +private: using FuncType = Block (SemiJoinProbeHelper::*)(JoinProbeContext &, JoinProbeWorkerData &); FuncType func_ptr_has_null = nullptr; FuncType func_ptr_no_null = nullptr; @@ -176,5 +64,4 @@ class SemiJoinProbeHelper : public JoinProbeHelperUtil const HashJoinPointerTable & pointer_table; }; - } // namespace DB diff --git a/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h b/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h new file mode 100644 index 00000000000..180d74f1fb3 --- /dev/null +++ b/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h @@ -0,0 +1,155 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include + +namespace DB +{ + +class SemiJoinProbeListBase +{ +public: + virtual ~SemiJoinProbeListBase() = 0; + virtual size_t activeSlots() const = 0; +}; + +/// A reusable, index‑based doubly‑linked circular list for managing semi join pending probe rows. +/// Supports O(1) append/remove by index. +template +class SemiJoinProbeList final : public SemiJoinProbeListBase +{ +public: + using IndexType = UInt32; + + struct ProbeRow + { + using KeyGetterType = typename KeyGetter::Type; + using KeyType = typename KeyGetterType::KeyType; + using HashValueType = typename KeyGetter::HashValueType; + + RowPtr build_row_ptr; + bool has_null_eq_from_in; + UInt16 pace; + HashValueType hash; + KeyType key; + }; + + class Iterator + { + public: + Iterator(SemiJoinProbeList & list, IndexType idx) + : list(list) + , idx(idx) + {} + + inline IndexType getIndex() const { return idx; } + + ProbeRow & operator*() const { return list.probe_rows[idx]; } + ProbeRow * operator->() const { return &list.probe_rows[idx]; } + + Iterator & operator++() + { + idx = list.probe_rows[idx].next_idx; + return *this; + } + + bool operator!=(const Iterator & other) const { return idx != other.idx; } + + private: + SemiJoinProbeList & list; + IndexType idx; + }; + + SemiJoinProbeList() = default; + + /// After reset(n), it holds n entries plus a sentinel at index n. + void reset(size_t n) + { + // reset should be called after all active slots are removed + RUNTIME_CHECK(active_count == 0); + RUNTIME_CHECK(n <= UINT32_MAX); + probe_rows.resize(n + 1); + sentinel_idx = static_cast(n); + // Initialize sentinel self-loop + probe_rows[sentinel_idx].prev_idx = sentinel_idx; + probe_rows[sentinel_idx].next_idx = sentinel_idx; + +#ifndef NDEBUG + // Isolate all slots + for (IndexType i = 0; i < n; ++i) + { + probe_rows[i].prev_idx = i; + probe_rows[i].next_idx = i; + } +#endif + } + + /// Returns the number of usable slots in the list (excluding the sentinel). 
+ inline size_t slotCapacity() const { return probe_rows.size() - 1; } + + size_t activeSlots() const override { return active_count; } + + /// Append an existing slot by index at the tail (before sentinel). + inline void append(IndexType idx) + { +#ifndef NDEBUG + assert(idx < slotCapacity()); + assert(probe_rows[idx].prev_idx == idx && probe_rows[idx].next_idx == idx); +#endif + ++active_count; + IndexType tail = probe_rows[sentinel_idx].prev_idx; + probe_rows[tail].next_idx = idx; + probe_rows[idx].prev_idx = tail; + probe_rows[idx].next_idx = sentinel_idx; + probe_rows[sentinel_idx].prev_idx = idx; + } + + /// Remove a slot by index from the list. + inline void remove(IndexType idx) + { +#ifndef NDEBUG + assert(idx < slotCapacity()); + assert(probe_rows[idx].prev_idx != idx && probe_rows[idx].next_idx != idx); + assert(active_count > 0); +#endif + --active_count; + IndexType prev = probe_rows[idx].prev_idx; + IndexType next = probe_rows[idx].next_idx; + probe_rows[prev].next_idx = next; + probe_rows[next].prev_idx = prev; + } + + Iterator begin() { return Iterator(*this, probe_rows[sentinel_idx].next_idx); } + Iterator end() { return Iterator(*this, sentinel_idx); } + + ProbeRow & at(IndexType idx) { assert(idx < slotCapacity()); return probe_rows[idx]; } + const ProbeRow & at(IndexType idx) const { assert(idx < slotCapacity()); return probe_rows[idx]; } + +private: + struct WrapProbeRow : ProbeRow + { + /// Embedded list indexes + IndexType prev_idx; + IndexType next_idx; + }; + + PaddedPODArray probe_rows; + IndexType sentinel_idx = 0; + size_t active_count = 0; +}; + +} // namespace DB From d1499bdc0e27092bd35e0b23d4f1b647dbb9c2d1 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Fri, 25 Apr 2025 01:50:41 +0800 Subject: [PATCH 52/84] format Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h 
b/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h index 180d74f1fb3..4013b119611 100644 --- a/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h +++ b/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h @@ -136,8 +136,16 @@ class SemiJoinProbeList final : public SemiJoinProbeListBase Iterator begin() { return Iterator(*this, probe_rows[sentinel_idx].next_idx); } Iterator end() { return Iterator(*this, sentinel_idx); } - ProbeRow & at(IndexType idx) { assert(idx < slotCapacity()); return probe_rows[idx]; } - const ProbeRow & at(IndexType idx) const { assert(idx < slotCapacity()); return probe_rows[idx]; } + ProbeRow & at(IndexType idx) + { + assert(idx < slotCapacity()); + return probe_rows[idx]; + } + const ProbeRow & at(IndexType idx) const + { + assert(idx < slotCapacity()); + return probe_rows[idx]; + } private: struct WrapProbeRow : ProbeRow From 418f7c4d3ac327d1300407712d60f9ac830e0f56 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Fri, 25 Apr 2025 02:14:10 +0800 Subject: [PATCH 53/84] u Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp index 2baeb05793b..a6ff7d415b8 100644 --- a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp @@ -352,16 +352,9 @@ void NO_INLINE SemiJoinProbeHelper::probeFillColumnsFromList( auto & probe_row = *iter; RowPtr ptr = probe_row.build_row_ptr; auto idx = iter.getIndex(); - if (ptr == nullptr) - { - setNotMatched(ctx, idx, probe_row.has_null_eq_from_in); - ++iter; - probe_list->remove(idx); - continue; - } size_t end_offset = std::min(settings.max_block_size, current_offset + probe_row.pace); size_t prev_offset = current_offset; - while (true) + while (ptr != nullptr) { const auto & key2 = key_getter.deserializeJoinKey(ptr + key_offset); bool key_is_equal = joinKeyIsEqual(key_getter, 
probe_row.key, key2, probe_row.hash, ptr); @@ -379,8 +372,6 @@ void NO_INLINE SemiJoinProbeHelper::probeFillColumnsFromList( } ptr = getNextRowPtr(ptr); - if (ptr == nullptr) - break; } if (prev_offset == current_offset) { From 4ba216dde2b6bfb47807bd657e84aa2154f4a216 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Tue, 29 Apr 2025 04:18:09 +0800 Subject: [PATCH 54/84] u Signed-off-by: gengliqi --- .../src/Interpreters/JoinV2/HashJoinProbe.cpp | 29 +- .../src/Interpreters/JoinV2/SemiJoinProbe.cpp | 340 +++++++++++++++--- dbms/src/Interpreters/JoinV2/SemiJoinProbe.h | 10 +- .../Interpreters/JoinV2/SemiJoinProbeList.h | 31 +- 4 files changed, 328 insertions(+), 82 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index a68da79e4f3..0c65be2f559 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -658,11 +658,8 @@ void JoinProbeHelper::probeFillColumns(JoinProbeContext & ctx, JoinProbeWorkerDa RowPtr ptr = ctx.current_build_row_ptr; bool is_matched = ctx.current_row_is_matched; size_t collision = 0; - size_t key_offset = sizeof(RowPtr); - if constexpr (KeyGetterType::joinKeyCompareHashFirst()) - { - key_offset += sizeof(HashValueType); - } + constexpr size_t key_offset + = sizeof(RowPtr) + (KeyGetterType::joinKeyCompareHashFirst() ? 
sizeof(HashValueType) : 0); #define NOT_MATCHED(not_matched) \ if constexpr (Adder::need_not_matched) \ @@ -801,8 +798,10 @@ struct ProbePrefetchState KeyType key{}; union { - RowPtr ptr = nullptr; - std::atomic * pointer_ptr; + /// Used when stage is FindHeader + std::atomic * pointer_ptr = nullptr; + /// Used when stage is FindNext + RowPtr ptr; }; }; @@ -820,10 +819,11 @@ void JoinProbeHelper::probeFillColumnsPrefetch( using Adder = JoinProbeAdder; auto & key_getter = *static_cast(ctx.key_getter.get()); + const size_t probe_prefetch_step = settings.probe_prefetch_step; if unlikely (!ctx.prefetch_states) { ctx.prefetch_states = decltype(ctx.prefetch_states)( - static_cast(new ProbePrefetchState[settings.probe_prefetch_step]), + static_cast(new ProbePrefetchState[probe_prefetch_step]), [](void * ptr) { delete[] static_cast *>(ptr); }); } auto * states = static_cast *>(ctx.prefetch_states.get()); @@ -833,11 +833,8 @@ void JoinProbeHelper::probeFillColumnsPrefetch( size_t k = ctx.prefetch_iter; size_t current_offset = wd.result_block.rows(); size_t collision = 0; - size_t key_offset = sizeof(RowPtr); - if constexpr (KeyGetterType::joinKeyCompareHashFirst()) - { - key_offset += sizeof(HashValueType); - } + constexpr size_t key_offset + = sizeof(RowPtr) + (KeyGetterType::joinKeyCompareHashFirst() ? sizeof(HashValueType) : 0); #define NOT_MATCHED(not_matched, idx) \ if constexpr (Adder::need_not_matched) \ @@ -850,7 +847,6 @@ void JoinProbeHelper::probeFillColumnsPrefetch( } \ } - const size_t probe_prefetch_step = settings.probe_prefetch_step; while (idx < ctx.rows || active_states > 0) { k = k == probe_prefetch_step ? 
0 : k; @@ -864,11 +860,10 @@ void JoinProbeHelper::probeFillColumnsPrefetch( const auto & key2 = key_getter.deserializeJoinKey(ptr + key_offset); bool key_is_equal = joinKeyIsEqual(key_getter, state->key, key2, state->hash, ptr); collision += !key_is_equal; + if constexpr (Adder::need_not_matched) + state->is_matched |= key_is_equal; if (key_is_equal) { - if constexpr (Adder::need_not_matched) - state->is_matched = true; - if constexpr (Adder::need_matched) { bool is_end = Adder::addMatched( diff --git a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp index a6ff7d415b8..a4fab61f345 100644 --- a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp @@ -17,6 +17,7 @@ #include "Interpreters/JoinV2/SemiJoinProbeList.h" #include "Parsers/ASTTablesInSelectQuery.h" +#include "ext/scope_guard.h" #ifdef TIFLASH_ENABLE_AVX_SUPPORT ASSERT_USE_AVX2_COMPILE_FLAG @@ -195,22 +196,17 @@ Block SemiJoinProbeHelper::probeImpl(JoinProbeContext & ctx, JoinProbeWorkerData Stopwatch watch; if (pointer_table.enableProbePrefetch()) { - //probeFillColumnsPrefetch( - // ctx, - // wd, - // added_columns); + probeFillColumnsPrefetch( + ctx, + wd, + added_columns); } else { - probeFillColumnsFromList( + probeFillColumns( ctx, wd, added_columns); - if (wd.result_block.rows() < settings.max_block_size) - probeFillColumns( - ctx, - wd, - added_columns); } wd.probe_hash_table_time += watch.elapsedFromLastTime(); @@ -227,10 +223,8 @@ static constexpr UInt16 INITIAL_PACE = 4; static constexpr UInt16 MAX_PACE = 8192; SEMI_JOIN_PROBE_HELPER_TEMPLATE -void SemiJoinProbeHelper::probeFillColumns( - JoinProbeContext & ctx, - JoinProbeWorkerData & wd, - MutableColumns & added_columns) +void NO_INLINE +SemiJoinProbeHelper::probeFillColumns(JoinProbeContext & ctx, JoinProbeWorkerData & wd, MutableColumns & added_columns) { using KeyGetterType = typename KeyGetter::Type; using Hash = typename KeyGetter::Hash; @@ 
-240,7 +234,6 @@ void SemiJoinProbeHelper::probeFillColumns( auto * probe_list = static_cast *>(ctx.semi_join_probe_list.get()); RUNTIME_CHECK(probe_list->slotCapacity() == ctx.rows); size_t current_offset = wd.result_block.rows(); - size_t idx = ctx.current_row_idx; size_t collision = 0; size_t key_offset = sizeof(RowPtr); if constexpr (KeyGetterType::joinKeyCompareHashFirst()) @@ -248,6 +241,59 @@ void SemiJoinProbeHelper::probeFillColumns( key_offset += sizeof(HashValueType); } + SCOPE_EXIT({ + flushInsertBatch(wd, added_columns); + fillNullMapWithZero(added_columns); + + wd.collision += collision; + }); + + auto iter_end = probe_list->end(); + for (auto iter = probe_list->begin(); iter != iter_end;) + { + auto & probe_row = *iter; + RowPtr ptr = probe_row.build_row_ptr; + auto idx = iter.getIndex(); + size_t end_offset = std::min(settings.max_block_size, current_offset + probe_row.pace); + size_t prev_offset = current_offset; + while (ptr != nullptr) + { + const auto & key2 = key_getter.deserializeJoinKey(ptr + key_offset); + bool key_is_equal = joinKeyIsEqual(key_getter, probe_row.key, key2, probe_row.hash, ptr); + collision += !key_is_equal; + current_offset += key_is_equal; + if (key_is_equal) + { + wd.selective_offsets.push_back(idx); + insertRowToBatch(wd, added_columns, ptr + key_offset + key_getter.getRequiredKeyOffset(key2)); + if unlikely (current_offset >= end_offset) + { + ptr = getNextRowPtr(ptr); + break; + } + } + + ptr = getNextRowPtr(ptr); + } + if (prev_offset == current_offset) + { + setNotMatched(ctx, idx, probe_row.has_null_eq_from_in); + iter = probe_list->remove(iter); + continue; + } + probe_row.build_row_ptr = ptr; + if (current_offset - prev_offset >= probe_row.pace) + probe_row.pace = std::min(MAX_PACE, probe_row.pace * 2U); + if unlikely (current_offset >= settings.max_block_size) + break; + + ++iter; + } + + if (current_offset >= settings.max_block_size) + return; + + size_t idx = ctx.current_row_idx; for (; idx < ctx.rows; ++idx) { 
if constexpr (has_null_map) @@ -273,7 +319,6 @@ void SemiJoinProbeHelper::probeFillColumns( { if (!containOtherTag(ptr, hash_tag)) { - ptr = nullptr; setNotMatched(ctx, idx); continue; } @@ -287,9 +332,9 @@ void SemiJoinProbeHelper::probeFillColumns( const auto & key2 = key_getter.deserializeJoinKey(ptr + key_offset); bool key_is_equal = joinKeyIsEqual(key_getter, key, key2, hash, ptr); collision += !key_is_equal; + current_offset += key_is_equal; if (key_is_equal) { - ++current_offset; wd.selective_offsets.push_back(idx); insertRowToBatch(wd, added_columns, ptr + key_offset + key_getter.getRequiredKeyOffset(key2)); if unlikely (current_offset >= end_offset) @@ -311,85 +356,276 @@ void SemiJoinProbeHelper::probeFillColumns( probe_list->append(idx); auto & probe_row = probe_list->at(idx); probe_row.build_row_ptr = ptr; - probe_row.has_null_eq_from_in = false; - probe_row.pace = INITIAL_PACE * 2; + if constexpr (has_other_eq_from_in_cond) + probe_row.has_null_eq_from_in = false; + probe_row.pace = std::min(MAX_PACE, INITIAL_PACE * 2U); probe_row.hash = hash; probe_row.key = key; if unlikely (current_offset >= settings.max_block_size) { - if (ptr == nullptr) - ++idx; + ++idx; break; } } - ctx.current_row_idx = idx; - wd.collision += collision; } +enum class ProbePrefetchStage : UInt8 +{ + None, + FindHeader, + FindNext, +}; + +template +struct ProbePrefetchState +{ + using KeyGetterType = typename KeyGetter::Type; + using KeyType = typename KeyGetterType::KeyType; + using HashValueType = typename KeyGetter::HashValueType; + + ProbePrefetchStage stage = ProbePrefetchStage::None; + bool is_matched : 1 = false; + bool has_null_eq_from_in : 1 = false; + union + { + /// Used when stage is FindHeader + UInt16 hash_tag = 0; + /// Used when stage is FindNext + UInt16 remaining_pace; + }; + UInt32 index = 0; + HashValueType hash = 0; + KeyType key{}; + union + { + /// Used when stage is FindHeader + std::atomic * pointer_ptr = nullptr; + /// Used when stage is FindNext + 
RowPtr ptr; + }; +}; + +#define PREFETCH_READ(ptr) __builtin_prefetch((ptr), 0 /* rw==read */, 3 /* locality */) + SEMI_JOIN_PROBE_HELPER_TEMPLATE -void NO_INLINE SemiJoinProbeHelper::probeFillColumnsFromList( +void NO_INLINE SemiJoinProbeHelper::probeFillColumnsPrefetch( JoinProbeContext & ctx, JoinProbeWorkerData & wd, MutableColumns & added_columns) { using KeyGetterType = typename KeyGetter::Type; + using Hash = typename KeyGetter::Hash; using HashValueType = typename KeyGetter::HashValueType; auto & key_getter = *static_cast(ctx.key_getter.get()); auto * probe_list = static_cast *>(ctx.semi_join_probe_list.get()); RUNTIME_CHECK(probe_list->slotCapacity() == ctx.rows); - size_t current_offset = wd.result_block.rows(); - size_t collision = 0; - size_t key_offset = sizeof(RowPtr); - if constexpr (KeyGetterType::joinKeyCompareHashFirst()) + const size_t probe_prefetch_step = settings.probe_prefetch_step; + if unlikely (!ctx.prefetch_states) { - key_offset += sizeof(HashValueType); + ctx.prefetch_states = decltype(ctx.prefetch_states)( + static_cast(new ProbePrefetchState[probe_prefetch_step]), + [](void * ptr) { delete[] static_cast *>(ptr); }); } + auto * states = static_cast *>(ctx.prefetch_states.get()); + + size_t idx = ctx.current_row_idx; + size_t active_states = ctx.prefetch_active_states; + size_t k = ctx.prefetch_iter; + size_t current_offset = wd.result_block.rows(); + size_t collision = 0; + constexpr size_t key_offset + = sizeof(RowPtr) + (KeyGetterType::joinKeyCompareHashFirst() ? 
sizeof(HashValueType) : 0); + + size_t list_active_slots = probe_list->activeSlots(); + auto iter = probe_list->begin(); auto iter_end = probe_list->end(); - for (auto iter = probe_list->begin(); iter != iter_end;) + while (idx < ctx.rows || active_states > 0 || list_active_slots > 0) { - auto & probe_row = *iter; - RowPtr ptr = probe_row.build_row_ptr; - auto idx = iter.getIndex(); - size_t end_offset = std::min(settings.max_block_size, current_offset + probe_row.pace); - size_t prev_offset = current_offset; - while (ptr != nullptr) + k = k == probe_prefetch_step ? 0 : k; + auto * state = &states[k]; + if (state->stage == ProbePrefetchStage::FindNext) { + RowPtr ptr = state->ptr; + RowPtr next_ptr = getNextRowPtr(ptr); + const auto & key2 = key_getter.deserializeJoinKey(ptr + key_offset); - bool key_is_equal = joinKeyIsEqual(key_getter, probe_row.key, key2, probe_row.hash, ptr); + bool key_is_equal = joinKeyIsEqual(key_getter, state->key, key2, state->hash, ptr); collision += !key_is_equal; + state->is_matched |= key_is_equal; + current_offset += key_is_equal; + state->remaining_pace -= key_is_equal; + bool remaining_pace_is_zero = false; if (key_is_equal) { - ++current_offset; - wd.selective_offsets.push_back(idx); + wd.selective_offsets.push_back(state->index); insertRowToBatch(wd, added_columns, ptr + key_offset + key_getter.getRequiredKeyOffset(key2)); - if unlikely (current_offset >= end_offset) + if unlikely (current_offset >= settings.max_block_size) { - ptr = getNextRowPtr(ptr); + probe_list->at(state->index).build_row_ptr = next_ptr; + state->stage = ProbePrefetchStage::None; + --active_states; break; } + if unlikely (state->remaining_pace == 0) + { + auto & probe_row = probe_list->at(state->index); + probe_row.build_row_ptr = next_ptr; + probe_row.pace = std::min(MAX_PACE, INITIAL_PACE * 2U); + remaining_pace_is_zero = true; + } } - ptr = getNextRowPtr(ptr); + if likely (!remaining_pace_is_zero) + { + if (next_ptr) + { + PREFETCH_READ(next_ptr); + 
state->ptr = next_ptr; + ++k; + continue; + } + + probe_list->at(state->index).build_row_ptr = next_ptr; + if (!state->is_matched) + { + setNotMatched(ctx, state->index, state->has_null_eq_from_in); + probe_list->remove(state->index); + } + } + + state->stage = ProbePrefetchStage::None; + --active_states; } - if (prev_offset == current_offset) + else if (state->stage == ProbePrefetchStage::FindHeader) { - setNotMatched(ctx, idx, probe_row.has_null_eq_from_in); - auto idx = iter.getIndex(); + RowPtr ptr = state->pointer_ptr->load(std::memory_order_relaxed); + if (ptr) + { + bool forward = true; + if constexpr (tagged_pointer) + { + if (containOtherTag(ptr, state->hash_tag)) + ptr = removeRowPtrTag(ptr); + else + forward = false; + } + if (forward) + { + PREFETCH_READ(ptr); + state->stage = ProbePrefetchStage::FindNext; + state->is_matched = false; + if constexpr (has_other_eq_from_in_cond) + state->has_null_eq_from_in = false; + state->remaining_pace = INITIAL_PACE; + state->ptr = ptr; + ++k; + + probe_list->append(state->index); + auto & probe_row = probe_list->at(state->index); + //probe_row.build_row_ptr = ptr; + if constexpr (has_other_eq_from_in_cond) + probe_row.has_null_eq_from_in = false; + probe_row.pace = INITIAL_PACE; + if constexpr (KeyGetterType::joinKeyCompareHashFirst()) + probe_row.hash = state->hash; + probe_row.key = state->key; + continue; + } + } + + state->stage = ProbePrefetchStage::None; + --active_states; + + setNotMatched(ctx, state->index); + } + + assert(state->stage == ProbePrefetchStage::None); + + while (list_active_slots > 0 && !iter->build_row_ptr) + { + setNotMatched(ctx, iter.getIndex()); + --list_active_slots; + ++iter; + } + + if (list_active_slots > 0) + { + assert(iter != iter_end); + assert(iter->build_row_ptr); + + auto & probe_row = *iter; + PREFETCH_READ(probe_row.build_row_ptr); + state->stage = ProbePrefetchStage::FindNext; + state->is_matched = false; + if constexpr (has_other_eq_from_in_cond) + state->has_null_eq_from_in 
= probe_row.has_null_eq_from_in; + state->remaining_pace = probe_row.pace; + if constexpr (KeyGetterType::joinKeyCompareHashFirst()) + state->hash = probe_row.hash; + state->key = probe_row.key; + state->ptr = probe_row.build_row_ptr; + ++iter; - probe_list->remove(idx); + --list_active_slots; + ++active_states; + ++k; continue; } - probe_row.build_row_ptr = ptr; - if (current_offset - prev_offset >= probe_row.pace) - probe_row.pace = std::min(MAX_PACE, probe_row.pace * 2U); - if unlikely (current_offset >= settings.max_block_size) - break; - ++iter; + if constexpr (has_null_map) + { + while (idx < ctx.rows) + { + if (!(*ctx.null_map)[idx]) + break; + + setNotMatched(ctx, idx); + ++idx; + } + } + + if unlikely (idx >= ctx.rows) + { + ++k; + continue; + } + + const auto & key = key_getter.getJoinKeyWithBuffer(idx); + auto hash = static_cast(Hash()(key)); + size_t bucket = pointer_table.getBucketNum(hash); + state->pointer_ptr = pointer_table.getPointerTable() + bucket; + PREFETCH_READ(state->pointer_ptr); + + state->key = key; + if constexpr (tagged_pointer) + state->hash_tag = hash & ROW_PTR_TAG_MASK; + if constexpr (KeyGetterType::joinKeyCompareHashFirst()) + state->hash = hash; + state->index = idx; + state->stage = ProbePrefetchStage::FindHeader; + ++active_states; + ++idx; + ++k; } + for (size_t i = 0; i < probe_prefetch_step; ++i) + { + auto * state = &states[i]; + if (state->stage == ProbePrefetchStage::FindNext) + { + auto & probe_row = probe_list->at(state->index); + probe_row.build_row_ptr = state->ptr; + } + } + + flushInsertBatch(wd, added_columns); + fillNullMapWithZero(added_columns); + + ctx.current_row_idx = idx; + ctx.prefetch_active_states = active_states; + ctx.prefetch_iter = k; wd.collision += collision; } diff --git a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.h b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.h index ab5291505eb..f3911bbc2e2 100644 --- a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.h +++ 
b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.h @@ -44,16 +44,10 @@ class SemiJoinProbeHelper : public JoinProbeHelperUtil SEMI_JOIN_PROBE_HELPER_TEMPLATE void NO_INLINE probeFillColumns(JoinProbeContext & ctx, JoinProbeWorkerData & wd, MutableColumns & added_columns); + SEMI_JOIN_PROBE_HELPER_TEMPLATE void NO_INLINE - probeFillColumnsFromList(JoinProbeContext & ctx, JoinProbeWorkerData & wd, MutableColumns & added_columns); - - //SEMI_JOIN_PROBE_HELPER_TEMPLATE - //void NO_INLINE - //probeFillColumnsPrefetch(JoinProbeContext & ctx, JoinProbeWorkerData & wd, MutableColumns & added_columns); - //SEMI_JOIN_PROBE_HELPER_TEMPLATE - //void NO_INLINE - //probeFillColumnsPrefetchFromList(JoinProbeContext & ctx, JoinProbeWorkerData & wd, MutableColumns & added_columns); + probeFillColumnsPrefetch(JoinProbeContext & ctx, JoinProbeWorkerData & wd, MutableColumns & added_columns); private: using FuncType = Block (SemiJoinProbeHelper::*)(JoinProbeContext &, JoinProbeWorkerData &); diff --git a/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h b/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h index 4013b119611..c831c36180c 100644 --- a/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h +++ b/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h @@ -56,6 +56,14 @@ class SemiJoinProbeList final : public SemiJoinProbeListBase , idx(idx) {} + Iterator(const Iterator & other) = default; + Iterator & operator=(const Iterator & other) + { + assert(&list == &other.list); + idx = other.idx; + return *this; + } + inline IndexType getIndex() const { return idx; } ProbeRow & operator*() const { return list.probe_rows[idx]; } @@ -67,7 +75,11 @@ class SemiJoinProbeList final : public SemiJoinProbeListBase return *this; } - bool operator!=(const Iterator & other) const { return idx != other.idx; } + bool operator!=(const Iterator & other) const + { + assert(&list == &other.list); + return idx != other.idx; + } private: SemiJoinProbeList & list; @@ -106,12 +118,12 @@ class SemiJoinProbeList final : 
public SemiJoinProbeListBase /// Append an existing slot by index at the tail (before sentinel). inline void append(IndexType idx) { -#ifndef NDEBUG assert(idx < slotCapacity()); +#ifndef NDEBUG assert(probe_rows[idx].prev_idx == idx && probe_rows[idx].next_idx == idx); #endif ++active_count; - IndexType tail = probe_rows[sentinel_idx].prev_idx; + auto tail = probe_rows[sentinel_idx].prev_idx; probe_rows[tail].next_idx = idx; probe_rows[idx].prev_idx = tail; probe_rows[idx].next_idx = sentinel_idx; @@ -121,10 +133,10 @@ class SemiJoinProbeList final : public SemiJoinProbeListBase /// Remove a slot by index from the list. inline void remove(IndexType idx) { -#ifndef NDEBUG assert(idx < slotCapacity()); - assert(probe_rows[idx].prev_idx != idx && probe_rows[idx].next_idx != idx); assert(active_count > 0); +#ifndef NDEBUG + assert(probe_rows[idx].prev_idx != idx && probe_rows[idx].next_idx != idx); #endif --active_count; IndexType prev = probe_rows[idx].prev_idx; @@ -133,6 +145,15 @@ class SemiJoinProbeList final : public SemiJoinProbeListBase probe_rows[next].prev_idx = prev; } + /// Remove a slot by iterator. 
+ inline Iterator remove(Iterator iter) + { + auto idx = iter.getIndex(); + ++iter; + remove(idx); + return iter; + } + Iterator begin() { return Iterator(*this, probe_rows[sentinel_idx].next_idx); } Iterator end() { return Iterator(*this, sentinel_idx); } From 68c3b57497214381343f62bf53a1619d4282a207 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Wed, 30 Apr 2025 02:31:14 +0800 Subject: [PATCH 55/84] u Signed-off-by: gengliqi --- .../Flash/Planner/Plans/PhysicalJoinV2.cpp | 6 +- dbms/src/Interpreters/JoinV2/HashJoinKey.cpp | 3 +- .../src/Interpreters/JoinV2/HashJoinProbe.cpp | 30 +- dbms/src/Interpreters/JoinV2/HashJoinProbe.h | 4 +- .../src/Interpreters/JoinV2/SemiJoinProbe.cpp | 289 ++++++++++++++++-- dbms/src/Interpreters/JoinV2/SemiJoinProbe.h | 24 +- .../Interpreters/JoinV2/SemiJoinProbeList.cpp | 40 +++ .../Interpreters/JoinV2/SemiJoinProbeList.h | 39 +-- 8 files changed, 372 insertions(+), 63 deletions(-) create mode 100644 dbms/src/Interpreters/JoinV2/SemiJoinProbeList.cpp diff --git a/dbms/src/Flash/Planner/Plans/PhysicalJoinV2.cpp b/dbms/src/Flash/Planner/Plans/PhysicalJoinV2.cpp index 2359be8da8c..cfb5ccab1e0 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalJoinV2.cpp +++ b/dbms/src/Flash/Planner/Plans/PhysicalJoinV2.cpp @@ -210,15 +210,11 @@ bool PhysicalJoinV2::isSupported(const tipb::Join & join) { case Inner: case LeftOuter: - if (!tiflash_join.getBuildJoinKeys().empty()) - return true; - break; case Semi: case Anti: case LeftOuterSemi: case LeftOuterAnti: - if (!tiflash_join.getBuildJoinKeys().empty() && join.other_conditions_size() == 0 - && join.other_eq_conditions_from_in_size() == 0) + if (!tiflash_join.getBuildJoinKeys().empty()) return true; break; //case RightOuter: diff --git a/dbms/src/Interpreters/JoinV2/HashJoinKey.cpp b/dbms/src/Interpreters/JoinV2/HashJoinKey.cpp index 0db00123a35..bbe4510ab33 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinKey.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinKey.cpp @@ -79,8 +79,7 @@ std::unique_ptr> 
createHashJoinKeyGetter( using KeyGetterType##METHOD = typename HashJoinKeyGetterForType::Type; \ return std::unique_ptr>( \ static_cast(new KeyGetterType##METHOD(collators)), \ - [](void * ptr) { delete reinterpret_cast(ptr); }); \ - break; + [](void * ptr) { delete reinterpret_cast(ptr); }); APPLY_FOR_HASH_JOIN_VARIANTS(M) #undef M diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index 0c65be2f559..ac9f233681b 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -74,7 +74,7 @@ void JoinProbeContext::prepareForHashProbe( HashJoinKeyMethod method, ASTTableJoin::Kind kind, bool has_other_condition, - bool has_other_eq_from_in_condition, + bool has_other_eq_cond_from_in, const Names & key_names, const String & filter_column, const NameSet & probe_output_name_set, @@ -127,7 +127,7 @@ void JoinProbeContext::prepareForHashProbe( { left_semi_match_res.clear(); left_semi_match_res.resize_fill_zero(rows); - if (has_other_eq_from_in_condition) + if (has_other_eq_cond_from_in) { left_semi_match_null_res.clear(); left_semi_match_null_res.resize_fill_zero(rows); @@ -141,7 +141,9 @@ void JoinProbeContext::prepareForHashProbe( if (SemiJoinProbeHelper::isSupported(kind, has_other_condition)) { - //semi_join_probe_list->reset(rows); + if unlikely (!semi_join_probe_list) + semi_join_probe_list = createSemiJoinProbeList(method); + semi_join_probe_list->reset(rows); } is_prepared = true; @@ -1040,6 +1042,17 @@ Block JoinProbeHelper::handleOtherConditions( non_equal_conditions.other_cond_expr->execute(exec_block); + SCOPE_EXIT({ + RUNTIME_CHECK(wd.result_block.columns() == left_columns + right_columns); + /// Clear the data in result_block. 
+ for (size_t i = 0; i < left_columns + right_columns; ++i) + { + auto column = wd.result_block.getByPosition(i).column->assumeMutable(); + column->popBack(column->size()); + wd.result_block.getByPosition(i).column = std::move(column); + } + }); + size_t rows = exec_block.rows(); // Ensure BASE_OFFSETS is accessed within bound. // It must be true because max_block_size <= BASE_OFFSETS.size(HASH_JOIN_MAX_BLOCK_SIZE_UPPER_BOUND). @@ -1226,17 +1239,6 @@ Block JoinProbeHelper::handleOtherConditions( } }; - SCOPE_EXIT({ - RUNTIME_CHECK(wd.result_block.columns() == left_columns + right_columns); - /// Clear the data in result_block. - for (size_t i = 0; i < left_columns + right_columns; ++i) - { - auto column = wd.result_block.getByPosition(i).column->assumeMutable(); - column->popBack(column->size()); - wd.result_block.getByPosition(i).column = std::move(column); - } - }); - size_t length = std::min(result_size, remaining_insert_size); fill_matched(0, length); if (result_size >= remaining_insert_size) diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h index 3f4846702b7..698f1dba409 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h @@ -53,7 +53,7 @@ struct JoinProbeContext /// For (anti) semi join with other conditions. IColumn::Offsets semi_selective_offsets; /// For (left outer) (anti) semi join with other conditions. 
- std::unique_ptr semi_join_probe_list; + std::unique_ptr semi_join_probe_list; size_t prefetch_active_states = 0; size_t prefetch_iter = 0; @@ -76,7 +76,7 @@ struct JoinProbeContext HashJoinKeyMethod method, ASTTableJoin::Kind kind, bool has_other_condition, - bool has_other_eq_from_in_condition, + bool has_other_eq_cond_from_in, const Names & key_names, const String & filter_column, const NameSet & probe_output_name_set, diff --git a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp index a4fab61f345..49fc217ab77 100644 --- a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp @@ -17,6 +17,7 @@ #include "Interpreters/JoinV2/SemiJoinProbeList.h" #include "Parsers/ASTTablesInSelectQuery.h" +#include "common/defines.h" #include "ext/scope_guard.h" #ifdef TIFLASH_ENABLE_AVX_SUPPORT @@ -88,20 +89,20 @@ SemiJoinProbeHelper::SemiJoinProbeHelper(const HashJoin * join) // SemiJoinProbeHelper only handles semi join with other conditions RUNTIME_CHECK(join->has_other_condition); -#define CALL3(KeyGetter, JoinType, has_other_eq_from_in_cond, tagged_pointer) \ +#define CALL3(KeyGetter, JoinType, has_other_eq_cond_from_in, tagged_pointer) \ { \ func_ptr_has_null \ - = &SemiJoinProbeHelper::probeImpl; \ + = &SemiJoinProbeHelper::probeImpl; \ func_ptr_no_null \ - = &SemiJoinProbeHelper::probeImpl; \ + = &SemiJoinProbeHelper::probeImpl; \ } -#define CALL2(KeyGetter, JoinType, has_other_eq_from_in_cond) \ +#define CALL2(KeyGetter, JoinType, has_other_eq_cond_from_in) \ { \ if (pointer_table.enableTaggedPointer()) \ - CALL3(KeyGetter, JoinType, has_other_eq_from_in_cond, true) \ + CALL3(KeyGetter, JoinType, has_other_eq_cond_from_in, true) \ else \ - CALL3(KeyGetter, JoinType, has_other_eq_from_in_cond, false) \ + CALL3(KeyGetter, JoinType, has_other_eq_cond_from_in, false) \ } #define CALL1(KeyGetter, JoinType) \ @@ -165,6 +166,8 @@ Block SemiJoinProbeHelper::probe(JoinProbeContext & ctx, 
JoinProbeWorkerData & w SEMI_JOIN_PROBE_HELPER_TEMPLATE Block SemiJoinProbeHelper::probeImpl(JoinProbeContext & ctx, JoinProbeWorkerData & wd) { + static_assert(kind == Semi || kind == Anti || kind == LeftOuterAnti || kind == LeftOuterSemi); + if unlikely (ctx.rows == 0) return join->output_block_after_finalize; @@ -176,6 +179,9 @@ Block SemiJoinProbeHelper::probeImpl(JoinProbeContext & ctx, JoinProbeWorkerData RUNTIME_CHECK(ctx.left_semi_match_null_res.size() == ctx.rows); } + wd.selective_offsets.clear(); + wd.selective_offsets.reserve(settings.max_block_size); + size_t left_columns = join->left_sample_block_pruned.columns(); size_t right_columns = join->right_sample_block_pruned.columns(); if (!wd.result_block) @@ -196,14 +202,14 @@ Block SemiJoinProbeHelper::probeImpl(JoinProbeContext & ctx, JoinProbeWorkerData Stopwatch watch; if (pointer_table.enableProbePrefetch()) { - probeFillColumnsPrefetch( + probeFillColumnsPrefetch( ctx, wd, added_columns); } else { - probeFillColumns( + probeFillColumns( ctx, wd, added_columns); @@ -215,7 +221,30 @@ Block SemiJoinProbeHelper::probeImpl(JoinProbeContext & ctx, JoinProbeWorkerData for (size_t i = 0; i < right_columns; ++i) wd.result_block.safeGetByPosition(left_columns + i).column = std::move(added_columns[i]); - if (ctx.isProbeFinished()) {} + if (!wd.selective_offsets.empty()) + { + for (size_t i = 0; i < left_columns; ++i) + { + if (!join->left_required_flag_for_other_condition[i]) + continue; + wd.result_block.safeGetByPosition(i).column->assumeMutable()->insertSelectiveFrom( + *ctx.block.safeGetByPosition(i).column.get(), + wd.selective_offsets); + } + + wd.replicate_time += watch.elapsedFromLastTime(); + + handleOtherConditions(ctx, wd); + } + + if (ctx.isProbeFinished()) + { + if constexpr (kind == Semi || kind == Anti) + return genResultBlockForSemi(ctx); + else + return genResultBlockForLeftOuterSemi(ctx, has_other_eq_cond_from_in); + } + return join->output_block_after_finalize; } @@ -277,7 +306,7 @@ 
SemiJoinProbeHelper::probeFillColumns(JoinProbeContext & ctx, JoinProbeWorkerDat } if (prev_offset == current_offset) { - setNotMatched(ctx, idx, probe_row.has_null_eq_from_in); + setNotMatched(ctx, idx, probe_row.has_null_eq_from_in); iter = probe_list->remove(iter); continue; } @@ -356,7 +385,7 @@ SemiJoinProbeHelper::probeFillColumns(JoinProbeContext & ctx, JoinProbeWorkerDat probe_list->append(idx); auto & probe_row = probe_list->at(idx); probe_row.build_row_ptr = ptr; - if constexpr (has_other_eq_from_in_cond) + if constexpr (has_other_eq_cond_from_in) probe_row.has_null_eq_from_in = false; probe_row.pace = std::min(MAX_PACE, INITIAL_PACE * 2U); probe_row.hash = hash; @@ -449,6 +478,7 @@ void NO_INLINE SemiJoinProbeHelper::probeFillColumnsPrefetch( { RowPtr ptr = state->ptr; RowPtr next_ptr = getNextRowPtr(ptr); + state->ptr = next_ptr; const auto & key2 = key_getter.deserializeJoinKey(ptr + key_offset); bool key_is_equal = joinKeyIsEqual(key_getter, state->key, key2, state->hash, ptr); @@ -472,7 +502,7 @@ void NO_INLINE SemiJoinProbeHelper::probeFillColumnsPrefetch( { auto & probe_row = probe_list->at(state->index); probe_row.build_row_ptr = next_ptr; - probe_row.pace = std::min(MAX_PACE, INITIAL_PACE * 2U); + probe_row.pace = std::min(MAX_PACE, probe_row.pace * 2U); remaining_pace_is_zero = true; } } @@ -482,7 +512,6 @@ void NO_INLINE SemiJoinProbeHelper::probeFillColumnsPrefetch( if (next_ptr) { PREFETCH_READ(next_ptr); - state->ptr = next_ptr; ++k; continue; } @@ -490,7 +519,7 @@ void NO_INLINE SemiJoinProbeHelper::probeFillColumnsPrefetch( probe_list->at(state->index).build_row_ptr = next_ptr; if (!state->is_matched) { - setNotMatched(ctx, state->index, state->has_null_eq_from_in); + setNotMatched(ctx, state->index, state->has_null_eq_from_in); probe_list->remove(state->index); } } @@ -516,7 +545,7 @@ void NO_INLINE SemiJoinProbeHelper::probeFillColumnsPrefetch( PREFETCH_READ(ptr); state->stage = ProbePrefetchStage::FindNext; state->is_matched = false; - 
if constexpr (has_other_eq_from_in_cond) + if constexpr (has_other_eq_cond_from_in) state->has_null_eq_from_in = false; state->remaining_pace = INITIAL_PACE; state->ptr = ptr; @@ -524,8 +553,8 @@ void NO_INLINE SemiJoinProbeHelper::probeFillColumnsPrefetch( probe_list->append(state->index); auto & probe_row = probe_list->at(state->index); - //probe_row.build_row_ptr = ptr; - if constexpr (has_other_eq_from_in_cond) + probe_row.build_row_ptr = ptr; + if constexpr (has_other_eq_cond_from_in) probe_row.has_null_eq_from_in = false; probe_row.pace = INITIAL_PACE; if constexpr (KeyGetterType::joinKeyCompareHashFirst()) @@ -545,9 +574,9 @@ void NO_INLINE SemiJoinProbeHelper::probeFillColumnsPrefetch( while (list_active_slots > 0 && !iter->build_row_ptr) { - setNotMatched(ctx, iter.getIndex()); + setNotMatched(ctx, iter.getIndex(), iter->has_null_eq_from_in); + iter = probe_list->remove(iter); --list_active_slots; - ++iter; } if (list_active_slots > 0) @@ -559,7 +588,7 @@ void NO_INLINE SemiJoinProbeHelper::probeFillColumnsPrefetch( PREFETCH_READ(probe_row.build_row_ptr); state->stage = ProbePrefetchStage::FindNext; state->is_matched = false; - if constexpr (has_other_eq_from_in_cond) + if constexpr (has_other_eq_cond_from_in) state->has_null_eq_from_in = probe_row.has_null_eq_from_in; state->remaining_pace = probe_row.pace; if constexpr (KeyGetterType::joinKeyCompareHashFirst()) @@ -629,4 +658,224 @@ void NO_INLINE SemiJoinProbeHelper::probeFillColumnsPrefetch( wd.collision += collision; } +template +void SemiJoinProbeHelper::handleOtherConditions(JoinProbeContext & ctx, JoinProbeWorkerData & wd) +{ + const auto & left_sample_block_pruned = join->left_sample_block_pruned; + const auto & right_sample_block_pruned = join->right_sample_block_pruned; + const auto & non_equal_conditions = join->non_equal_conditions; + const auto & left_required_flag_for_other_condition = join->left_required_flag_for_other_condition; + + size_t left_columns = left_sample_block_pruned.columns(); 
+ size_t right_columns = right_sample_block_pruned.columns(); + // Some columns in wd.result_block may be empty so need to create another block to execute other condition expressions + Block exec_block; + RUNTIME_CHECK(wd.result_block.columns() == left_columns + right_columns); + for (size_t i = 0; i < left_columns; ++i) + { + if (left_required_flag_for_other_condition[i]) + exec_block.insert(wd.result_block.getByPosition(i)); + } + for (size_t i = 0; i < right_columns; ++i) + exec_block.insert(wd.result_block.getByPosition(left_columns + i)); + + non_equal_conditions.other_cond_expr->execute(exec_block); + + SCOPE_EXIT({ + RUNTIME_CHECK(wd.result_block.columns() == left_columns + right_columns); + /// Clear the data in result_block. + for (size_t i = 0; i < left_columns + right_columns; ++i) + { + auto column = wd.result_block.getByPosition(i).column->assumeMutable(); + column->popBack(column->size()); + wd.result_block.getByPosition(i).column = std::move(column); + } + }); + + const ColumnUInt8::Container *other_eq_from_in_column_data = nullptr, *other_column_data = nullptr; + ConstNullMapPtr other_eq_from_in_null_map = nullptr, other_null_map = nullptr; + ColumnPtr other_eq_from_in_column, other_column; + + bool has_other_eq_cond_from_in = !non_equal_conditions.other_eq_cond_from_in_name.empty(); + if (has_other_eq_cond_from_in) + { + other_eq_from_in_column = exec_block.getByName(non_equal_conditions.other_eq_cond_from_in_name).column; + auto is_nullable_col = [&]() { + if (other_eq_from_in_column->isColumnNullable()) + return true; + if (other_eq_from_in_column->isColumnConst()) + { + const auto & const_col = typeid_cast(*other_eq_from_in_column); + return const_col.getDataColumn().isColumnNullable(); + } + return false; + }; + // nullable, const(nullable) + RUNTIME_CHECK_MSG( + is_nullable_col(), + "The equal condition from in column should be nullable, otherwise it should be used as join key"); + + std::tie(other_eq_from_in_column_data, 
other_eq_from_in_null_map) + = getDataAndNullMapVectorFromFilterColumn(other_eq_from_in_column); + } + + bool has_other_cond = !non_equal_conditions.other_cond_name.empty(); + bool has_other_cond_null_map = false; + if (has_other_cond) + { + other_column = exec_block.getByName(non_equal_conditions.other_cond_name).column; + std::tie(other_column_data, other_null_map) = getDataAndNullMapVectorFromFilterColumn(other_column); + has_other_cond_null_map = other_null_map != nullptr; + } + +#define CALL(has_other_eq_cond_from_in, has_other_cond, has_other_cond_null_map) \ + { \ + checkExprResults( \ + ctx, \ + wd.selective_offsets, \ + other_eq_from_in_column_data, \ + other_eq_from_in_null_map, \ + other_column_data, \ + other_null_map); \ + } + + if (has_other_eq_cond_from_in) + { + if (has_other_cond) + { + if (has_other_cond_null_map) + CALL(true, true, true) + else + CALL(true, true, false) + } + else + CALL(true, false, false) + } + else + { + RUNTIME_CHECK(has_other_cond); + if (has_other_cond_null_map) + CALL(false, true, true) + else + CALL(false, true, false) + } +#undef CALL +} + +template < + typename KeyGetter, + ASTTableJoin::Kind kind, + bool has_other_eq_cond_from_in, + bool has_other_cond, + bool has_other_cond_null_map> +void SemiJoinProbeHelper::checkExprResults( + JoinProbeContext & ctx, + IColumn::Offsets & selective_offsets, + const ColumnUInt8::Container * other_eq_column, + ConstNullMapPtr other_eq_null_map, + const ColumnUInt8::Container * other_column, + ConstNullMapPtr other_null_map) +{ + static_assert(has_other_cond || has_other_eq_cond_from_in); + auto * probe_list = static_cast *>(ctx.semi_join_probe_list.get()); + size_t sz = selective_offsets.size(); + if constexpr (has_other_eq_cond_from_in) + { + RUNTIME_CHECK(sz == other_eq_column->size()); + RUNTIME_CHECK(sz == other_eq_null_map->size()); + } + if constexpr (has_other_cond) + { + RUNTIME_CHECK(sz == other_column->size()); + if constexpr (has_other_cond_null_map) + RUNTIME_CHECK(sz == 
other_null_map->size()); + } + for (size_t i = 0; i < sz; ++i) + { + auto index = selective_offsets[i]; + if (!probe_list->contains(index)) + continue; + if constexpr (has_other_cond) + { + if constexpr (has_other_cond_null_map) + { + if ((*other_null_map)[i]) + { + // If other expr is NULL, this row is not included in the result set. + continue; + } + } + if (!(*other_column)[i]) + { + // If other expr is 0, this row is not included in the result set. + continue; + } + } + if constexpr (has_other_eq_cond_from_in) + { + auto & probe_row = probe_list->at(index); + bool is_eq_null = (*other_eq_null_map)[i]; + probe_row.has_null_eq_from_in |= is_eq_null; + if (!is_eq_null && (*other_eq_column)[i]) + { + setMatched(ctx, index); + probe_list->remove(index); + } + } + else + { + // other expr is true, so the result is true for this row that has matched right row(s). + setMatched(ctx, index); + probe_list->remove(index); + } + } +} + +Block SemiJoinProbeHelper::genResultBlockForSemi(JoinProbeContext & ctx) +{ + RUNTIME_CHECK(join->kind == Semi || join->kind == Anti); + RUNTIME_CHECK(ctx.isProbeFinished()); + + Block res_block = join->output_block_after_finalize.cloneEmpty(); + size_t columns = res_block.columns(); + for (size_t i = 0; i < columns; ++i) + { + auto & dst_column = res_block.getByPosition(i); + dst_column.column->assumeMutable()->insertSelectiveFrom( + *ctx.block.getByName(dst_column.name).column.get(), + ctx.semi_selective_offsets); + } + + return res_block; +} + +Block SemiJoinProbeHelper::genResultBlockForLeftOuterSemi(JoinProbeContext & ctx, bool has_other_eq_cond_from_in) +{ + RUNTIME_CHECK(join->kind == LeftOuterSemi || join->kind == LeftOuterAnti); + RUNTIME_CHECK(ctx.isProbeFinished()); + + Block res_block = join->output_block_after_finalize.cloneEmpty(); + size_t columns = res_block.columns(); + size_t match_helper_column_index = res_block.getPositionByName(join->match_helper_name); + for (size_t i = 0; i < columns; ++i) + { + if (i == 
match_helper_column_index) + continue; + res_block.getByPosition(i) = ctx.block.getByName(res_block.getByPosition(i).name); + } + + MutableColumnPtr match_helper_column_ptr = res_block.getByPosition(match_helper_column_index).column->cloneEmpty(); + auto * match_helper_column = typeid_cast(match_helper_column_ptr.get()); + if (has_other_eq_cond_from_in) + match_helper_column->getNullMapColumn().getData().swap(ctx.left_semi_match_null_res); + else + match_helper_column->getNullMapColumn().getData().resize_fill_zero(ctx.rows); + auto * match_helper_res = &typeid_cast &>(match_helper_column->getNestedColumn()).getData(); + match_helper_res->swap(ctx.left_semi_match_res); + + res_block.getByPosition(match_helper_column_index).column = std::move(match_helper_column_ptr); + + return res_block; +} + } // namespace DB diff --git a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.h b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.h index f3911bbc2e2..9f1526b8ed1 100644 --- a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.h +++ b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.h @@ -17,6 +17,8 @@ #include #include +#include "Parsers/ASTTablesInSelectQuery.h" + namespace DB { @@ -25,7 +27,7 @@ namespace DB typename KeyGetter, \ ASTTableJoin::Kind kind, \ bool has_null_map, \ - bool has_other_eq_from_in_cond, \ + bool has_other_eq_cond_from_in, \ bool tagged_pointer> class HashJoin; @@ -49,6 +51,26 @@ class SemiJoinProbeHelper : public JoinProbeHelperUtil void NO_INLINE probeFillColumnsPrefetch(JoinProbeContext & ctx, JoinProbeWorkerData & wd, MutableColumns & added_columns); + template + void handleOtherConditions(JoinProbeContext & ctx, JoinProbeWorkerData & wd); + + template < + typename KeyGetter, + ASTTableJoin::Kind kind, + bool has_other_eq_cond_from_in, + bool has_other_cond, + bool has_other_cond_null_map> + void checkExprResults( + JoinProbeContext & ctx, + IColumn::Offsets & selective_offsets, + const ColumnUInt8::Container * other_eq_column, + ConstNullMapPtr other_eq_null_map, + 
const ColumnUInt8::Container * other_column, + ConstNullMapPtr other_null_map); + + Block genResultBlockForSemi(JoinProbeContext & ctx); + Block genResultBlockForLeftOuterSemi(JoinProbeContext & ctx, bool has_other_eq_cond_from_in); + private: using FuncType = Block (SemiJoinProbeHelper::*)(JoinProbeContext &, JoinProbeWorkerData &); FuncType func_ptr_has_null = nullptr; diff --git a/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.cpp b/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.cpp new file mode 100644 index 00000000000..a3f678dd7c8 --- /dev/null +++ b/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.cpp @@ -0,0 +1,40 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include + +namespace DB +{ + +std::unique_ptr createSemiJoinProbeList(HashJoinKeyMethod method) +{ + switch (method) + { +#define M(METHOD) \ + case HashJoinKeyMethod::METHOD: \ + using KeyGetterType##METHOD = HashJoinKeyGetterForType; \ + return std::make_unique>(); + APPLY_FOR_HASH_JOIN_VARIANTS(M) +#undef M + + default: + throw Exception( + fmt::format("Unknown JOIN keys variant {}.", magic_enum::enum_name(method)), + ErrorCodes::UNKNOWN_SET_DATA_VARIANT); + } +} + +} // namespace DB diff --git a/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h b/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h index c831c36180c..248a575cb72 100644 --- a/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h +++ b/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h @@ -15,22 +15,24 @@ #pragma once #include +#include #include namespace DB { -class SemiJoinProbeListBase +class ISemiJoinProbeList { public: - virtual ~SemiJoinProbeListBase() = 0; + virtual ~ISemiJoinProbeList() = default; + virtual void reset(size_t n) = 0; virtual size_t activeSlots() const = 0; }; /// A reusable, index‑based doubly‑linked circular list for managing semi join pending probe rows. /// Supports O(1) append/remove by index. template -class SemiJoinProbeList final : public SemiJoinProbeListBase +class SemiJoinProbeList final : public ISemiJoinProbeList { public: using IndexType = UInt32; @@ -89,7 +91,7 @@ class SemiJoinProbeList final : public SemiJoinProbeListBase SemiJoinProbeList() = default; /// After reset(n), it holds n entries plus a sentinel at index n. 
- void reset(size_t n) + void reset(size_t n) override { // reset should be called after all active slots are removed RUNTIME_CHECK(active_count == 0); @@ -99,15 +101,6 @@ class SemiJoinProbeList final : public SemiJoinProbeListBase // Initialize sentinel self-loop probe_rows[sentinel_idx].prev_idx = sentinel_idx; probe_rows[sentinel_idx].next_idx = sentinel_idx; - -#ifndef NDEBUG - // Isolate all slots - for (IndexType i = 0; i < n; ++i) - { - probe_rows[i].prev_idx = i; - probe_rows[i].next_idx = i; - } -#endif } /// Returns the number of usable slots in the list (excluding the sentinel). @@ -115,13 +108,19 @@ class SemiJoinProbeList final : public SemiJoinProbeListBase size_t activeSlots() const override { return active_count; } + /// Returns true if the slot is currently linked in the list. + inline bool contains(IndexType idx) + { + assert(idx < slotCapacity()); + return probe_rows[idx].in_list; + } + /// Append an existing slot by index at the tail (before sentinel). inline void append(IndexType idx) { assert(idx < slotCapacity()); -#ifndef NDEBUG - assert(probe_rows[idx].prev_idx == idx && probe_rows[idx].next_idx == idx); -#endif + assert(!contains(idx)); + probe_rows[idx].in_list = true; ++active_count; auto tail = probe_rows[sentinel_idx].prev_idx; probe_rows[tail].next_idx = idx; @@ -134,10 +133,9 @@ class SemiJoinProbeList final : public SemiJoinProbeListBase inline void remove(IndexType idx) { assert(idx < slotCapacity()); + assert(contains(idx)); + probe_rows[idx].in_list = false; assert(active_count > 0); -#ifndef NDEBUG - assert(probe_rows[idx].prev_idx != idx && probe_rows[idx].next_idx != idx); -#endif --active_count; IndexType prev = probe_rows[idx].prev_idx; IndexType next = probe_rows[idx].next_idx; @@ -171,6 +169,7 @@ class SemiJoinProbeList final : public SemiJoinProbeListBase private: struct WrapProbeRow : ProbeRow { + bool in_list; /// Embedded list indexes IndexType prev_idx; IndexType next_idx; @@ -181,4 +180,6 @@ class SemiJoinProbeList 
final : public SemiJoinProbeListBase size_t active_count = 0; }; +std::unique_ptr createSemiJoinProbeList(HashJoinKeyMethod method); + } // namespace DB From 1bdeace2a6de23cebe0264307720264fcc51fae2 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Wed, 30 Apr 2025 02:35:49 +0800 Subject: [PATCH 56/84] u Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp index 49fc217ab77..6c03d10b8a4 100644 --- a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp @@ -248,7 +248,7 @@ Block SemiJoinProbeHelper::probeImpl(JoinProbeContext & ctx, JoinProbeWorkerData return join->output_block_after_finalize; } -static constexpr UInt16 INITIAL_PACE = 4; +static constexpr UInt16 INITIAL_PACE = 2; static constexpr UInt16 MAX_PACE = 8192; SEMI_JOIN_PROBE_HELPER_TEMPLATE From 537c1238355343ee1ede6c228a64cb3c75e35119 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Wed, 30 Apr 2025 14:21:50 +0800 Subject: [PATCH 57/84] u Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h b/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h index 248a575cb72..2a821bd52de 100644 --- a/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h +++ b/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h @@ -101,6 +101,8 @@ class SemiJoinProbeList final : public ISemiJoinProbeList // Initialize sentinel self-loop probe_rows[sentinel_idx].prev_idx = sentinel_idx; probe_rows[sentinel_idx].next_idx = sentinel_idx; + for (size_t i = 0; i < n; ++i) + probe_rows[i].in_list = false; } /// Returns the number of usable slots in the list (excluding the sentinel). 
From dd986a5be23d40abe64135baa57d8ab2384db846 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Thu, 1 May 2025 00:36:05 +0800 Subject: [PATCH 58/84] fix Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp index 6c03d10b8a4..4a616255898 100644 --- a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp @@ -516,7 +516,7 @@ void NO_INLINE SemiJoinProbeHelper::probeFillColumnsPrefetch( continue; } - probe_list->at(state->index).build_row_ptr = next_ptr; + probe_list->at(state->index).build_row_ptr = nullptr; if (!state->is_matched) { setNotMatched(ctx, state->index, state->has_null_eq_from_in); @@ -591,6 +591,7 @@ void NO_INLINE SemiJoinProbeHelper::probeFillColumnsPrefetch( if constexpr (has_other_eq_cond_from_in) state->has_null_eq_from_in = probe_row.has_null_eq_from_in; state->remaining_pace = probe_row.pace; + state->index = iter.getIndex(); if constexpr (KeyGetterType::joinKeyCompareHashFirst()) state->hash = probe_row.hash; state->key = probe_row.key; From fcbb4008136959aa0a6569e46386f7d6131491d1 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Thu, 1 May 2025 02:41:17 +0800 Subject: [PATCH 59/84] fix bug Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp | 8 ++++++++ dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp | 6 ++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index ac9f233681b..b3c71247b9e 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -655,7 +655,11 @@ void JoinProbeHelper::probeFillColumns(JoinProbeContext & ctx, JoinProbeWorkerDa using Adder = JoinProbeAdder; auto & key_getter = *static_cast(ctx.key_getter.get()); + // Some 
columns in wd.result_block may remain empty due to late materialization for join with other conditions. + // But since all columns are cleared after handling other conditions, wd.result_block.rows() is always 0. size_t current_offset = wd.result_block.rows(); + if constexpr (has_other_condition) + RUNTIME_CHECK(current_offset == 0); size_t idx = ctx.current_row_idx; RowPtr ptr = ctx.current_build_row_ptr; bool is_matched = ctx.current_row_is_matched; @@ -833,7 +837,11 @@ void JoinProbeHelper::probeFillColumnsPrefetch( size_t idx = ctx.current_row_idx; size_t active_states = ctx.prefetch_active_states; size_t k = ctx.prefetch_iter; + // Some columns in wd.result_block may remain empty due to late materialization for join with other conditions. + // But since all columns are cleared after handling other conditions, wd.result_block.rows() is always 0. size_t current_offset = wd.result_block.rows(); + if constexpr (has_other_condition) + RUNTIME_CHECK(current_offset == 0); size_t collision = 0; constexpr size_t key_offset = sizeof(RowPtr) + (KeyGetterType::joinKeyCompareHashFirst() ? 
sizeof(HashValueType) : 0); diff --git a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp index 4a616255898..f603678770d 100644 --- a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp @@ -262,7 +262,7 @@ SemiJoinProbeHelper::probeFillColumns(JoinProbeContext & ctx, JoinProbeWorkerDat auto & key_getter = *static_cast(ctx.key_getter.get()); auto * probe_list = static_cast *>(ctx.semi_join_probe_list.get()); RUNTIME_CHECK(probe_list->slotCapacity() == ctx.rows); - size_t current_offset = wd.result_block.rows(); + size_t current_offset = 0; size_t collision = 0; size_t key_offset = sizeof(RowPtr); if constexpr (KeyGetterType::joinKeyCompareHashFirst()) @@ -462,7 +462,7 @@ void NO_INLINE SemiJoinProbeHelper::probeFillColumnsPrefetch( size_t idx = ctx.current_row_idx; size_t active_states = ctx.prefetch_active_states; size_t k = ctx.prefetch_iter; - size_t current_offset = wd.result_block.rows(); + size_t current_offset = 0; size_t collision = 0; constexpr size_t key_offset = sizeof(RowPtr) + (KeyGetterType::joinKeyCompareHashFirst() ? 
sizeof(HashValueType) : 0); @@ -647,6 +647,8 @@ void NO_INLINE SemiJoinProbeHelper::probeFillColumnsPrefetch( { auto & probe_row = probe_list->at(state->index); probe_row.build_row_ptr = state->ptr; + state->stage = ProbePrefetchStage::None; + --active_states; } } From 55f66a1dbb1735916292d63ae6068494066c81bf Mon Sep 17 00:00:00 2001 From: gengliqi Date: Thu, 1 May 2025 12:11:18 +0800 Subject: [PATCH 60/84] u Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp | 2 +- dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp index f603678770d..dec14b3a932 100644 --- a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp @@ -248,7 +248,7 @@ Block SemiJoinProbeHelper::probeImpl(JoinProbeContext & ctx, JoinProbeWorkerData return join->output_block_after_finalize; } -static constexpr UInt16 INITIAL_PACE = 2; +static constexpr UInt16 INITIAL_PACE = 4; static constexpr UInt16 MAX_PACE = 8192; SEMI_JOIN_PROBE_HELPER_TEMPLATE diff --git a/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h b/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h index 2a821bd52de..43bfbdbb009 100644 --- a/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h +++ b/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h @@ -101,8 +101,8 @@ class SemiJoinProbeList final : public ISemiJoinProbeList // Initialize sentinel self-loop probe_rows[sentinel_idx].prev_idx = sentinel_idx; probe_rows[sentinel_idx].next_idx = sentinel_idx; - for (size_t i = 0; i < n; ++i) - probe_rows[i].in_list = false; + probe_rows_in_list.clear(); + probe_rows_in_list.resize_fill_zero(n); } /// Returns the number of usable slots in the list (excluding the sentinel). 
@@ -114,7 +114,7 @@ class SemiJoinProbeList final : public ISemiJoinProbeList inline bool contains(IndexType idx) { assert(idx < slotCapacity()); - return probe_rows[idx].in_list; + return probe_rows_in_list[idx]; } /// Append an existing slot by index at the tail (before sentinel). @@ -122,7 +122,7 @@ class SemiJoinProbeList final : public ISemiJoinProbeList { assert(idx < slotCapacity()); assert(!contains(idx)); - probe_rows[idx].in_list = true; + probe_rows_in_list[idx] = true; ++active_count; auto tail = probe_rows[sentinel_idx].prev_idx; probe_rows[tail].next_idx = idx; @@ -136,7 +136,7 @@ class SemiJoinProbeList final : public ISemiJoinProbeList { assert(idx < slotCapacity()); assert(contains(idx)); - probe_rows[idx].in_list = false; + probe_rows_in_list[idx] = false; assert(active_count > 0); --active_count; IndexType prev = probe_rows[idx].prev_idx; @@ -171,13 +171,13 @@ class SemiJoinProbeList final : public ISemiJoinProbeList private: struct WrapProbeRow : ProbeRow { - bool in_list; /// Embedded list indexes IndexType prev_idx; IndexType next_idx; }; PaddedPODArray probe_rows; + PaddedPODArray probe_rows_in_list; IndexType sentinel_idx = 0; size_t active_count = 0; }; From a35c7412805485844dd0318b41d044a38e6d9ec2 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Tue, 6 May 2025 12:14:30 +0800 Subject: [PATCH 61/84] u Signed-off-by: gengliqi --- dbms/src/Columns/filterColumn.cpp | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/dbms/src/Columns/filterColumn.cpp b/dbms/src/Columns/filterColumn.cpp index 35c0eb06fcb..533c3672d8b 100644 --- a/dbms/src/Columns/filterColumn.cpp +++ b/dbms/src/Columns/filterColumn.cpp @@ -287,7 +287,7 @@ INSTANTIATE(Float64) #undef INSTANTIATE template -void filterImpl(const UInt8 * filt_pos, const UInt8 * filt_end, const T * data_pos, Container & res_data) +void filterImpl(const UInt8 * filt_pos, const UInt8 * filt_end, T const * data_pos, Container & res_data) { const UInt8 * filt_end_aligned = filt_pos 
+ (filt_end - filt_pos) / FILTER_SIMD_BYTES * FILTER_SIMD_BYTES; filterImplAligned(filt_pos, filt_end_aligned, data_pos, res_data); @@ -307,7 +307,7 @@ void filterImpl(const UInt8 * filt_pos, const UInt8 * filt_end, const T * data_p template void filterImpl( \ const UInt8 * filt_pos, \ const UInt8 * filt_end, \ - const T * data_pos, \ + T const * data_pos, \ Container & res_data); // NOLINT INSTANTIATE(UInt8, PaddedPODArray) @@ -322,16 +322,11 @@ INSTANTIATE(Int64, PaddedPODArray) INSTANTIATE(Int128, PaddedPODArray) INSTANTIATE(Float32, PaddedPODArray) INSTANTIATE(Float64, PaddedPODArray) +INSTANTIATE(char *, PaddedPODArray) INSTANTIATE(Decimal32, DecimalPaddedPODArray) INSTANTIATE(Decimal64, DecimalPaddedPODArray) INSTANTIATE(Decimal128, DecimalPaddedPODArray) INSTANTIATE(Decimal256, DecimalPaddedPODArray) -// Cannot use INSTANTIATE micro because `const T * data_pos` + `T: char *` will be intepreted as `const char **` -template void filterImpl>( - const UInt8 * filt_pos, - const UInt8 * filt_end, - char * const * data_pos, - PaddedPODArray & res_data); #undef INSTANTIATE From 7c29983bafa8fa109165014227dec9a2d0e1e487 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Wed, 7 May 2025 13:25:21 +0800 Subject: [PATCH 62/84] fix Signed-off-by: gengliqi --- contrib/simsimd | 2 +- contrib/usearch | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/contrib/simsimd b/contrib/simsimd index ff51434d90c..b672abb945c 160000 --- a/contrib/simsimd +++ b/contrib/simsimd @@ -1 +1 @@ -Subproject commit ff51434d90c66f916e94ff05b24530b127aa4cff +Subproject commit b672abb945ce47c7ae7f6e4b7e2e2e818ee0244f diff --git a/contrib/usearch b/contrib/usearch index 5ad2053521a..e414f46b440 160000 --- a/contrib/usearch +++ b/contrib/usearch @@ -1 +1 @@ -Subproject commit 5ad2053521ab432cd13e236d1d4e7788479a011b +Subproject commit e414f46b440fc2b989aca648fecfc4db35e22332 From efbf062afaf92881eba7ce61a812dbe7d748d7d8 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Wed, 7 May 2025 
13:29:54 +0800 Subject: [PATCH 63/84] u Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h b/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h index 43bfbdbb009..b2d1c6d5dfe 100644 --- a/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h +++ b/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h @@ -101,8 +101,8 @@ class SemiJoinProbeList final : public ISemiJoinProbeList // Initialize sentinel self-loop probe_rows[sentinel_idx].prev_idx = sentinel_idx; probe_rows[sentinel_idx].next_idx = sentinel_idx; - probe_rows_in_list.clear(); - probe_rows_in_list.resize_fill_zero(n); + active_rows_in_list.clear(); + active_rows_in_list.resize_fill_zero(n); } /// Returns the number of usable slots in the list (excluding the sentinel). @@ -114,7 +114,7 @@ class SemiJoinProbeList final : public ISemiJoinProbeList inline bool contains(IndexType idx) { assert(idx < slotCapacity()); - return probe_rows_in_list[idx]; + return active_rows_in_list[idx]; } /// Append an existing slot by index at the tail (before sentinel). 
@@ -122,7 +122,7 @@ class SemiJoinProbeList final : public ISemiJoinProbeList { assert(idx < slotCapacity()); assert(!contains(idx)); - probe_rows_in_list[idx] = true; + active_rows_in_list[idx] = true; ++active_count; auto tail = probe_rows[sentinel_idx].prev_idx; probe_rows[tail].next_idx = idx; @@ -136,11 +136,11 @@ class SemiJoinProbeList final : public ISemiJoinProbeList { assert(idx < slotCapacity()); assert(contains(idx)); - probe_rows_in_list[idx] = false; + active_rows_in_list[idx] = false; assert(active_count > 0); --active_count; - IndexType prev = probe_rows[idx].prev_idx; - IndexType next = probe_rows[idx].next_idx; + auto prev = probe_rows[idx].prev_idx; + auto next = probe_rows[idx].next_idx; probe_rows[prev].next_idx = next; probe_rows[next].prev_idx = prev; } @@ -177,7 +177,7 @@ class SemiJoinProbeList final : public ISemiJoinProbeList }; PaddedPODArray probe_rows; - PaddedPODArray probe_rows_in_list; + PaddedPODArray active_rows_in_list; IndexType sentinel_idx = 0; size_t active_count = 0; }; From 7a808f8822a81505198c66fb4cfdfd64551d9118 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Thu, 8 May 2025 01:52:18 +0800 Subject: [PATCH 64/84] fix Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp index 988dd3c163b..02f9cb8f797 100644 --- a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp @@ -466,7 +466,7 @@ void NO_INLINE SemiJoinProbeHelper::probeFillColumnsPrefetch( size_t list_active_slots = probe_list->activeSlots(); auto iter = probe_list->begin(); - auto iter_end = probe_list->end(); + auto iter_end [[maybe_unused]] = probe_list->end(); while (idx < ctx.rows || active_states > 0 || list_active_slots > 0) { k = k == probe_prefetch_step ? 
0 : k; From 74fb0564195dbd6cb4b6a643baa65e1700f68b87 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Thu, 15 May 2025 12:21:03 +0800 Subject: [PATCH 65/84] add test for SemiJoinProbeList Signed-off-by: gengliqi --- .../Interpreters/JoinV2/SemiJoinProbeList.h | 2 +- .../JoinV2/gtest_semi_join_probe_list.cpp | 71 +++++++++++++++++++ 2 files changed, 72 insertions(+), 1 deletion(-) create mode 100644 dbms/src/Interpreters/tests/JoinV2/gtest_semi_join_probe_list.cpp diff --git a/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h b/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h index b2d1c6d5dfe..ca39d434660 100644 --- a/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h +++ b/dbms/src/Interpreters/JoinV2/SemiJoinProbeList.h @@ -29,7 +29,7 @@ class ISemiJoinProbeList virtual size_t activeSlots() const = 0; }; -/// A reusable, index‑based doubly‑linked circular list for managing semi join pending probe rows. +/// A reusable, index‑based, doubly‑linked circular list for managing semi join pending probe rows. /// Supports O(1) append/remove by index. template class SemiJoinProbeList final : public ISemiJoinProbeList diff --git a/dbms/src/Interpreters/tests/JoinV2/gtest_semi_join_probe_list.cpp b/dbms/src/Interpreters/tests/JoinV2/gtest_semi_join_probe_list.cpp new file mode 100644 index 00000000000..e0469b42385 --- /dev/null +++ b/dbms/src/Interpreters/tests/JoinV2/gtest_semi_join_probe_list.cpp @@ -0,0 +1,71 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + + +namespace DB +{ +namespace tests +{ + +class SemiJoinProbeListTest : public ::testing::Test +{ +}; + +TEST_F(SemiJoinProbeListTest, TestRandom) +try +{ + // The KeyGetter type is unrelated to the functionality being tested. + SemiJoinProbeList> list; + std::random_device rd; + std::mt19937 g(rd()); + std::uniform_int_distribution dist; + + size_t n = dist(g) % 1000 + 1000; + + list.reset(n); + EXPECT_EQ(list.slotCapacity(), n); + EXPECT_EQ(list.activeSlots(), 0); + std::unordered_set s1, s2; + for (size_t i = 0; i < n; ++i) + s1.insert(i); + while (!s1.empty() && !s2.empty()) + { + EXPECT_EQ(list.activeSlots(), s1.size()); + bool is_append = !s1.empty() && dist(g) % 2 == 0; + if (is_append) + { + size_t append_idx = *s1.begin(); + EXPECT_TRUE(!list.contains(append_idx)); + s1.erase(append_idx); + s2.insert(append_idx); + list.append(append_idx); + continue; + } + size_t remove_idx = *s2.begin(); + EXPECT_TRUE(list.contains(remove_idx)); + s2.erase(remove_idx); + list.remove(remove_idx); + } + EXPECT_EQ(list.slotCapacity(), n); + EXPECT_EQ(list.activeSlots(), 0); +} +CATCH + +} // namespace tests +} // namespace DB \ No newline at end of file From 3eebd15435854ca178def55245ed359ec1af7d2d Mon Sep 17 00:00:00 2001 From: gengliqi Date: Thu, 15 May 2025 13:50:10 +0800 Subject: [PATCH 66/84] format Signed-off-by: gengliqi --- .../Interpreters/tests/JoinV2/gtest_semi_join_probe_list.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dbms/src/Interpreters/tests/JoinV2/gtest_semi_join_probe_list.cpp b/dbms/src/Interpreters/tests/JoinV2/gtest_semi_join_probe_list.cpp index e0469b42385..f9f48e56be9 100644 --- a/dbms/src/Interpreters/tests/JoinV2/gtest_semi_join_probe_list.cpp +++ b/dbms/src/Interpreters/tests/JoinV2/gtest_semi_join_probe_list.cpp @@ -14,6 +14,7 @@ #include #include + #include #include 
@@ -47,7 +48,7 @@ try while (!s1.empty() && !s2.empty()) { EXPECT_EQ(list.activeSlots(), s1.size()); - bool is_append = !s1.empty() && dist(g) % 2 == 0; + bool is_append = !s1.empty() && dist(g) % 2 == 0; if (is_append) { size_t append_idx = *s1.begin(); @@ -68,4 +69,4 @@ try CATCH } // namespace tests -} // namespace DB \ No newline at end of file +} // namespace DB From 67d1b16a66b4499c197cdeb00e3aca668d714398 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Fri, 6 Jun 2025 10:47:48 +0800 Subject: [PATCH 67/84] support right outer/semi/anti join Signed-off-by: gengliqi --- .../Flash/Planner/Plans/PhysicalJoinV2.cpp | 6 +- dbms/src/Flash/tests/gtest_join_executor.cpp | 2 + dbms/src/Interpreters/JoinV2/HashJoin.cpp | 69 ++-- dbms/src/Interpreters/JoinV2/HashJoin.h | 19 + .../src/Interpreters/JoinV2/HashJoinBuild.cpp | 79 ++++- dbms/src/Interpreters/JoinV2/HashJoinBuild.h | 9 +- .../JoinV2/HashJoinPointerTable.cpp | 19 +- .../src/Interpreters/JoinV2/HashJoinProbe.cpp | 203 +++++++++-- dbms/src/Interpreters/JoinV2/HashJoinProbe.h | 13 +- .../JoinV2/HashJoinProbeBuildScanner.cpp | 331 ++++++++++++++++++ .../JoinV2/HashJoinProbeBuildScanner.h | 68 ++++ .../Interpreters/JoinV2/HashJoinRowLayout.h | 51 ++- .../src/Interpreters/JoinV2/SemiJoinProbe.cpp | 14 +- .../Operators/HashJoinV2ProbeTransformOp.cpp | 60 +++- .../Operators/HashJoinV2ProbeTransformOp.h | 9 + 15 files changed, 836 insertions(+), 116 deletions(-) create mode 100644 dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp create mode 100644 dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.h diff --git a/dbms/src/Flash/Planner/Plans/PhysicalJoinV2.cpp b/dbms/src/Flash/Planner/Plans/PhysicalJoinV2.cpp index cfb5ccab1e0..08a691336e8 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalJoinV2.cpp +++ b/dbms/src/Flash/Planner/Plans/PhysicalJoinV2.cpp @@ -214,12 +214,12 @@ bool PhysicalJoinV2::isSupported(const tipb::Join & join) case Anti: case LeftOuterSemi: case LeftOuterAnti: + case RightOuter: + case 
RightSemi: + case RightAnti: if (!tiflash_join.getBuildJoinKeys().empty()) return true; break; - //case RightOuter: - //case RightSemi: - //case RightAnti: default: } return false; diff --git a/dbms/src/Flash/tests/gtest_join_executor.cpp b/dbms/src/Flash/tests/gtest_join_executor.cpp index 3f5ddcc5143..e0169501077 100644 --- a/dbms/src/Flash/tests/gtest_join_executor.cpp +++ b/dbms/src/Flash/tests/gtest_join_executor.cpp @@ -4795,6 +4795,7 @@ try auto request = context.scan("right_semi_family", "t") .join(context.scan("right_semi_family", "s"), type, {col("a")}, {}, {}, {}, {}, 0, false, 0) .build(context); + WRAP_FOR_JOIN_TEST_BEGIN executeAndAssertColumnsEqual(request, res); auto request_column_prune = context.scan("right_semi_family", "t") @@ -4802,6 +4803,7 @@ try .aggregation({Count(lit(static_cast(1)))}, {}) .build(context); ASSERT_COLUMNS_EQ_UR(genScalarCountResults(res), executeStreams(request_column_prune, 2)); + WRAP_FOR_JOIN_TEST_END } /// One join key(t.a = s.a) + other condition(t.c < s.c). 
diff --git a/dbms/src/Interpreters/JoinV2/HashJoin.cpp b/dbms/src/Interpreters/JoinV2/HashJoin.cpp index 305be3a2c99..9c74a1d9de0 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoin.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoin.cpp @@ -133,6 +133,7 @@ HashJoin::HashJoin( , log(Logger::get(join_req_id)) , has_other_condition(non_equal_conditions.other_cond_expr != nullptr) , output_columns(output_columns_) + , wait_probe_finished_future(std::make_shared(NotifyType::WAIT_ON_JOIN_PROBE_FINISH)) { RUNTIME_ASSERT(key_names_left.size() == key_names_right.size()); output_block = Block(output_columns); @@ -261,6 +262,7 @@ void HashJoin::initRowLayoutAndHashJoinMethod() if (c.column->valuesHaveFixedSize()) { row_layout.other_column_fixed_size += c.column->sizeOfValueIfFixed(); + row_layout.other_column_for_other_condition_fixed_size += c.column->sizeOfValueIfFixed(); row_layout.other_column_indexes.push_back({i, true}); } else @@ -455,11 +457,24 @@ bool HashJoin::finishOneProbe(size_t stream_index) if (active_probe_worker.fetch_sub(1) == 1) { FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::exception_mpp_hash_probe); + if (isRightOuterJoin(kind) || isRightSemiFamily(kind)) + join_probe_build_scanner = std::make_unique(this); + wait_probe_finished_future->finish(); return true; } return false; } +bool HashJoin::isAllProbeFinished() +{ + if (active_probe_worker > 0) + { + setNotifyFuture(wait_probe_finished_future.get()); + return false; + } + return true; +} + void HashJoin::workAfterBuildRowFinish() { size_t all_build_row_count = 0; @@ -478,13 +493,15 @@ void HashJoin::workAfterBuildRowFinish() enable_tagged_pointer, false); - /// Conservative threshold: trigger late materialization when lm_row_size average >= 16 bytes. 
- constexpr size_t trigger_lm_row_size_threshold = 16; bool late_materialization = false; size_t avg_lm_row_size = 0; - if (has_other_condition - && row_layout.other_column_count_for_other_condition < row_layout.other_column_indexes.size()) + if (shouldCheckLateMaterialization()) { + // Calculate the average row size of late materialization rows. + // If the average row size is greater than or equal to the threshold, enable late materialization. + // Otherwise, disable it. + // Note: this is a conservative threshold, enable late materialization when lm_row_size average >= 16 bytes. + constexpr size_t trigger_lm_row_size_threshold = 16; size_t total_lm_row_size = 0; size_t total_lm_row_count = 0; for (size_t i = 0; i < build_concurrency; ++i) @@ -556,11 +573,9 @@ void HashJoin::buildRowFromBlock(const Block & b, size_t stream_index) assertBlocksHaveEqualStructure(block, right_sample_block_pruned, "Join Build"); - bool check_lm_row_size = has_other_condition - && row_layout.other_column_count_for_other_condition < row_layout.other_column_indexes.size(); insertBlockToRowContainers( method, - needRecordNotInsertRows(kind), + kind, block, rows, key_columns, @@ -568,7 +583,7 @@ void HashJoin::buildRowFromBlock(const Block & b, size_t stream_index) row_layout, multi_row_containers, build_workers_data[stream_index], - check_lm_row_size); + shouldCheckLateMaterialization()); build_workers_data[stream_index].build_time += watch.elapsedMilliseconds(); } @@ -576,21 +591,20 @@ void HashJoin::buildRowFromBlock(const Block & b, size_t stream_index) bool HashJoin::buildPointerTable(size_t stream_index) { bool is_end; + size_t max_build_size = 2 * settings.max_block_size; switch (method) { -#define M(METHOD) \ - case HashJoinKeyMethod::METHOD: \ - using KeyGetterType##METHOD = HashJoinKeyGetterForType; \ - if constexpr (KeyGetterType##METHOD::Type::joinKeyCompareHashFirst()) \ - is_end = pointer_table.build( \ - build_workers_data[stream_index], \ - multi_row_containers, \ - 
settings.max_block_size); \ - else \ - is_end = pointer_table.build( \ - build_workers_data[stream_index], \ - multi_row_containers, \ - settings.max_block_size); \ +#define M(METHOD) \ + case HashJoinKeyMethod::METHOD: \ + using KeyGetterType##METHOD = HashJoinKeyGetterForType; \ + if constexpr (KeyGetterType##METHOD::Type::joinKeyCompareHashFirst()) \ + is_end = pointer_table.build( \ + build_workers_data[stream_index], \ + multi_row_containers, \ + max_build_size); \ + else \ + is_end \ + = pointer_table.build(build_workers_data[stream_index], multi_row_containers, max_build_size); \ break; APPLY_FOR_HASH_JOIN_VARIANTS(M) #undef M @@ -665,6 +679,19 @@ Block HashJoin::probeLastResultBlock(size_t stream_index) return {}; } +bool HashJoin::needProbeScanBuildSide() const +{ + return join_probe_build_scanner != nullptr; +} + +Block HashJoin::probeScanBuildSide(size_t stream_index) +{ + auto & wd = probe_workers_data[stream_index]; + Stopwatch all_watch; + SCOPE_EXIT({ probe_workers_data[stream_index].scan_build_side_time += all_watch.elapsedFromLastTime(); }); + return join_probe_build_scanner->scan(wd); +} + void HashJoin::removeUselessColumn(Block & block) const { const NameSet & probe_output_name_set = has_other_condition diff --git a/dbms/src/Interpreters/JoinV2/HashJoin.h b/dbms/src/Interpreters/JoinV2/HashJoin.h index e28292e6a61..a23967fbace 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoin.h +++ b/dbms/src/Interpreters/JoinV2/HashJoin.h @@ -19,11 +19,14 @@ #include #include #include +#include #include +#include #include #include #include #include +#include #include #include #include @@ -54,6 +57,8 @@ class HashJoin bool finishOneBuildRow(size_t stream_index); /// Return true if it is the last probe worker. bool finishOneProbe(size_t stream_index); + /// Return true if all probe work has finished. 
+ bool isAllProbeFinished(); void buildRowFromBlock(const Block & block, size_t stream_index); bool buildPointerTable(size_t stream_index); @@ -61,6 +66,9 @@ class HashJoin Block probeBlock(JoinProbeContext & ctx, size_t stream_index); Block probeLastResultBlock(size_t stream_index); + bool needProbeScanBuildSide() const; + Block probeScanBuildSide(size_t stream_index); + void removeUselessColumn(Block & block) const; /// Block's schema must be all_sample_block_pruned. Block removeUselessColumnForOutput(const Block & block) const; @@ -81,9 +89,17 @@ class HashJoin void workAfterBuildRowFinish(); + bool shouldCheckLateMaterialization() const + { + bool is_any_semi_join = isSemiFamily(kind) || isLeftOuterSemiFamily(kind) || isRightSemiFamily(kind); + return has_other_condition && !is_any_semi_join + && row_layout.other_column_count_for_other_condition < row_layout.other_column_indexes.size(); + } + private: friend JoinProbeHelper; friend SemiJoinProbeHelper; + friend JoinProbeBuildScanner; static const DataTypePtr match_helper_type; @@ -152,8 +168,11 @@ class HashJoin size_t probe_concurrency = 0; std::vector probe_workers_data; std::atomic active_probe_worker = 0; + OneTimeNotifyFuturePtr wait_probe_finished_future; std::unique_ptr join_probe_helper; std::unique_ptr semi_join_probe_helper; + /// Probe scan build side + std::unique_ptr join_probe_build_scanner; const JoinProfileInfoPtr profile_info = std::make_shared(); diff --git a/dbms/src/Interpreters/JoinV2/HashJoinBuild.cpp b/dbms/src/Interpreters/JoinV2/HashJoinBuild.cpp index 337b24e0500..1d72afcad78 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinBuild.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinBuild.cpp @@ -15,6 +15,9 @@ #include #include +#include "Interpreters/JoinUtils.h" +#include "Parsers/ASTTablesInSelectQuery.h" + namespace DB { namespace ErrorCodes @@ -25,6 +28,7 @@ extern const int UNKNOWN_SET_DATA_VARIANT; template void NO_INLINE insertBlockToRowContainersTypeImpl( + ASTTableJoin::Kind kind, 
Block & block, size_t rows, const ColumnRawPtrs & key_columns, @@ -43,7 +47,16 @@ void NO_INLINE insertBlockToRowContainersTypeImpl( key_getter.reset(key_columns, row_layout.raw_key_column_indexes.size()); wd.row_sizes.clear(); - wd.row_sizes.resize_fill(rows, row_layout.other_column_fixed_size); + bool is_right_semi_family = isRightSemiFamily(kind); + if (is_right_semi_family) + { + wd.row_sizes.resize_fill(rows, row_layout.other_column_for_other_condition_fixed_size); + wd.right_semi_selector.resize(rows); + } + else + { + wd.row_sizes.resize_fill(rows, row_layout.other_column_fixed_size); + } wd.hashes.resize(rows); /// The last partition is used to hold rows with null join key. constexpr size_t part_count = JOIN_BUILD_PARTITION_COUNT + 1; @@ -66,8 +79,15 @@ void NO_INLINE insertBlockToRowContainersTypeImpl( } } - for (const auto & [index, is_fixed_size] : row_layout.other_column_indexes) + size_t other_column_count; + if (is_right_semi_family) + other_column_count = row_layout.other_column_count_for_other_condition; + else + other_column_count = row_layout.other_column_indexes.size(); + for (size_t i = 0; i < other_column_count; ++i) { + size_t index = row_layout.other_column_indexes[i].first; + bool is_fixed_size = row_layout.other_column_indexes[i].second; if (!is_fixed_size) block.getByPosition(index).column->countSerializeByteSize(wd.row_sizes); } @@ -76,8 +96,12 @@ void NO_INLINE insertBlockToRowContainersTypeImpl( { if (has_null_map && (*null_map)[i]) { + // TODO: the non-key row does not needed for right semi join. 
However, IColumn::scatterTo must need a selector for all rows + if (is_right_semi_family) + wd.right_semi_selector[i] = part_count - 1; if constexpr (need_record_null_rows) { + RUNTIME_CHECK_MSG(false, "need_record_null_rows is not supported yet"); //TODO //wd.row_sizes[i] += sizeof(RowPtr); //wd.row_sizes[i] = alignRowSize(wd.row_sizes[i]); @@ -89,6 +113,8 @@ void NO_INLINE insertBlockToRowContainersTypeImpl( const auto & key = key_getter.getJoinKeyWithBuffer(i); wd.hashes[i] = static_cast(Hash()(key)); size_t part_num = getJoinBuildPartitionNum(wd.hashes[i]); + if (is_right_semi_family) + wd.right_semi_selector[i] = part_num; size_t ptr_and_key_size = sizeof(RowPtr) + key_getter.getJoinKeyByteSize(key); if constexpr (KeyGetterType::joinKeyCompareHashFirst()) @@ -126,12 +152,12 @@ void NO_INLINE insertBlockToRowContainersTypeImpl( ++wd.partition_row_count[part_num]; } - std::vector partition_column_row(part_count); + std::array partition_row_container; for (size_t i = 0; i < part_count; ++i) { if (wd.partition_row_count[i] > 0) { - auto & container = partition_column_row[i]; + auto & container = partition_row_container[i]; container.data.resize(wd.partition_row_sizes[i], CPU_CACHE_LINE_SIZE); wd.enable_tagged_pointer &= isRowPtrTagZero(container.data.data()); wd.enable_tagged_pointer &= isRowPtrTagZero(container.data.data() + wd.partition_row_sizes[i]); @@ -174,12 +200,12 @@ void NO_INLINE insertBlockToRowContainersTypeImpl( continue; } size_t part_num = getJoinBuildPartitionNum(wd.hashes[j]); - wd.row_ptrs.push_back(partition_column_row[part_num].data.data() + wd.partition_row_sizes[part_num]); + wd.row_ptrs.push_back(partition_row_container[part_num].data.data() + wd.partition_row_sizes[part_num]); auto & ptr = wd.row_ptrs.back(); assert((reinterpret_cast(ptr) & (ROW_ALIGN - 1)) == 0); wd.partition_row_sizes[part_num] += wd.row_sizes[j]; - partition_column_row[part_num].offsets.push_back(wd.partition_row_sizes[part_num]); + 
partition_row_container[part_num].offsets.push_back(wd.partition_row_sizes[part_num]); unalignedStore(ptr, nullptr); ptr += sizeof(RowPtr); @@ -191,14 +217,15 @@ void NO_INLINE insertBlockToRowContainersTypeImpl( } else { - partition_column_row[part_num].hashes.push_back(wd.hashes[j]); + partition_row_container[part_num].hashes.push_back(wd.hashes[j]); } const auto & key = key_getter.getJoinKeyWithBuffer(j); key_getter.serializeJoinKey(key, ptr); ptr += key_getter.getJoinKeyByteSize(key); } - for (const auto & [index, _] : row_layout.other_column_indexes) + for (size_t i = 0; i < other_column_count; ++i) { + size_t index = row_layout.other_column_indexes[i].first; if constexpr (has_null_map && !need_record_null_rows) block.getByPosition(index).column->serializeToPos(wd.row_ptrs, start, end - start, true); else @@ -206,16 +233,42 @@ void NO_INLINE insertBlockToRowContainersTypeImpl( } } + if (isRightSemiFamily(kind)) + { + IColumn::ScatterColumns scatter_columns(part_count); + for (size_t i = row_layout.other_column_count_for_other_condition; i < row_layout.other_column_indexes.size(); + ++i) + { + size_t index = row_layout.other_column_indexes[i].first; + auto & column_data = block.getByPosition(index); + size_t column_memory = column_data.column->byteSize(); + for (size_t j = 0; j < part_count; ++j) + { + auto new_column_data = column_data.cloneEmpty(); + if (wd.partition_row_count[j] > 0) + { + size_t memory_hint = column_memory * wd.partition_row_count[j] / rows + 16; + new_column_data.column->assumeMutable()->reserveWithTotalMemoryHint( + wd.partition_row_count[j], + memory_hint); + } + scatter_columns[j] = new_column_data.column->assumeMutable(); + partition_row_container[j].other_column_block.insert(std::move(new_column_data)); + } + column_data.column->scatterTo(scatter_columns, wd.right_semi_selector); + } + } + for (size_t i = 0; i < part_count; ++i) { if (wd.partition_row_count[i] > 0) - multi_row_containers[i]->insert(std::move(partition_column_row[i]), 
wd.partition_row_count[i]); + multi_row_containers[i]->insert(std::move(partition_row_container[i]), wd.partition_row_count[i]); } } template void insertBlockToRowContainersType( - bool need_record_null_rows, + ASTTableJoin::Kind kind, Block & block, size_t rows, const ColumnRawPtrs & key_columns, @@ -227,6 +280,7 @@ void insertBlockToRowContainersType( { #define CALL(has_null_map, need_record_null_rows) \ insertBlockToRowContainersTypeImpl( \ + kind, \ block, \ rows, \ key_columns, \ @@ -236,6 +290,7 @@ void insertBlockToRowContainersType( worker_data, \ check_lm_row_size); + bool need_record_null_rows = needRecordNotInsertRows(kind); if (null_map) { if (need_record_null_rows) @@ -256,7 +311,7 @@ void insertBlockToRowContainersType( void insertBlockToRowContainers( HashJoinKeyMethod method, - bool need_record_null_rows, + ASTTableJoin::Kind kind, Block & block, size_t rows, const ColumnRawPtrs & key_columns, @@ -272,7 +327,7 @@ void insertBlockToRowContainers( case HashJoinKeyMethod::METHOD: \ using KeyGetterType##METHOD = HashJoinKeyGetterForType; \ insertBlockToRowContainersType( \ - need_record_null_rows, \ + kind, \ block, \ rows, \ key_columns, \ diff --git a/dbms/src/Interpreters/JoinV2/HashJoinBuild.h b/dbms/src/Interpreters/JoinV2/HashJoinBuild.h index 906ca0841a6..294ef6522f2 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinBuild.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinBuild.h @@ -19,6 +19,8 @@ #include #include +#include "Parsers/ASTTablesInSelectQuery.h" + namespace DB { @@ -43,7 +45,7 @@ inline size_t getJoinBuildPartitionNum(HashValueType hash) struct alignas(CPU_CACHE_LINE_SIZE) JoinBuildWorkerData { std::unique_ptr> key_getter; - /// Count of not-null rows + /// Count of non-null-key rows size_t row_count = 0; RowPtr null_rows_list_head = nullptr; @@ -51,6 +53,7 @@ struct alignas(CPU_CACHE_LINE_SIZE) JoinBuildWorkerData PaddedPODArray row_sizes; PaddedPODArray hashes; RowPtrs row_ptrs; + IColumn::Selector right_semi_selector; PaddedPODArray 
partition_row_sizes; PaddedPODArray partition_row_count; @@ -61,7 +64,7 @@ struct alignas(CPU_CACHE_LINE_SIZE) JoinBuildWorkerData size_t build_pointer_table_time = 0; size_t build_pointer_table_size = 0; - ssize_t build_pointer_table_iter = -1; + ssize_t current_build_table_index = -1; size_t padding_size = 0; size_t all_size = 0; @@ -75,7 +78,7 @@ struct alignas(CPU_CACHE_LINE_SIZE) JoinBuildWorkerData void insertBlockToRowContainers( HashJoinKeyMethod method, - bool need_record_null_rows, + ASTTableJoin::Kind kind, Block & block, size_t rows, const ColumnRawPtrs & key_columns, diff --git a/dbms/src/Interpreters/JoinV2/HashJoinPointerTable.cpp b/dbms/src/Interpreters/JoinV2/HashJoinPointerTable.cpp index 1f785793ba0..985ec40791b 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinPointerTable.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinPointerTable.cpp @@ -88,11 +88,11 @@ bool HashJoinPointerTable::buildImpl( Stopwatch watch; size_t build_size = 0; bool is_end = false; - while (true) + do { RowContainer * container = nullptr; - if (wd.build_pointer_table_iter != -1) - container = multi_row_containers[wd.build_pointer_table_iter]->getNext(); + if (wd.current_build_table_index != -1) + container = multi_row_containers[wd.current_build_table_index]->getNext(); if (container == nullptr) { { @@ -103,7 +103,7 @@ bool HashJoinPointerTable::buildImpl( container = multi_row_containers[build_table_index]->getNext(); if (container != nullptr) { - wd.build_pointer_table_iter = build_table_index; + wd.current_build_table_index = build_table_index; build_table_index = (build_table_index + 1) % JOIN_BUILD_PARTITION_COUNT; break; } @@ -116,9 +116,9 @@ bool HashJoinPointerTable::buildImpl( break; } } - size_t size = container->size(); - build_size += size; - for (size_t i = 0; i < size; ++i) + size_t rows = container->size(); + build_size += rows; + for (size_t i = 0; i < rows; ++i) { RowPtr row_ptr = container->getRowPtr(i); assert((reinterpret_cast(row_ptr) & (ROW_ALIGN - 1)) 
== 0); @@ -149,10 +149,7 @@ bool HashJoinPointerTable::buildImpl( if (old_head != nullptr) unalignedStore(row_ptr, old_head); } - - if (build_size >= max_build_size) - break; - } + } while (build_size < max_build_size); wd.build_pointer_table_size += build_size; wd.build_pointer_table_time += watch.elapsedMilliseconds(); return is_end; diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index b3c71247b9e..45c6de18d19 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -156,8 +156,10 @@ void JoinProbeHelperUtil::flushInsertBatch(JoinProbeWorkerData & wd, MutableColu { IColumn * column = added_columns[column_index].get(); if (is_nullable) - column = &static_cast(*added_columns[column_index]).getNestedColumn(); + column = &static_cast(*column).getNestedColumn(); column->deserializeAndInsertFromPos(wd.insert_batch, true); + if constexpr (last_flush) + column->flushNTAlignBuffer(); } size_t add_size; @@ -169,42 +171,22 @@ void JoinProbeHelperUtil::flushInsertBatch(JoinProbeWorkerData & wd, MutableColu { size_t column_index = row_layout.other_column_indexes[i].first; added_columns[column_index]->deserializeAndInsertFromPos(wd.insert_batch, true); + if constexpr (last_flush) + added_columns[column_index]->flushNTAlignBuffer(); } if constexpr (late_materialization) wd.row_ptrs_for_lm.insert(wd.insert_batch.begin(), wd.insert_batch.end()); - if constexpr (last_flush) - { - for (auto [column_index, is_nullable] : row_layout.raw_key_column_indexes) - { - IColumn * column = added_columns[column_index].get(); - if (is_nullable) - column = &static_cast(*added_columns[column_index]).getNestedColumn(); - column->flushNTAlignBuffer(); - } - - size_t add_size; - if constexpr (late_materialization) - add_size = row_layout.other_column_count_for_other_condition; - else - add_size = row_layout.other_column_indexes.size(); - for (size_t i = 0; i < add_size; ++i) - { - 
size_t column_index = row_layout.other_column_indexes[i].first; - added_columns[column_index]->flushNTAlignBuffer(); - } - } - wd.insert_batch.clear(); } -template void JoinProbeHelperUtil::fillNullMapWithZero(MutableColumns & added_columns) const { for (auto [column_index, is_nullable] : row_layout.raw_key_column_indexes) { if (is_nullable) { + RUNTIME_CHECK(added_columns[column_index]->isColumnNullable()); auto & nullable_column = static_cast(*added_columns[column_index]); size_t data_size = nullable_column.getNestedColumn().size(); size_t nullmap_size = nullable_column.getNullMapColumn().size(); @@ -246,7 +228,7 @@ struct JoinProbeAdder static void flush(JoinProbeHelper & helper, JoinProbeWorkerData & wd, MutableColumns & added_columns) { helper.flushInsertBatch(wd, added_columns); - helper.fillNullMapWithZero(added_columns); + helper.fillNullMapWithZero(added_columns); } }; @@ -292,7 +274,7 @@ struct JoinProbeAdder static void flush(JoinProbeHelper & helper, JoinProbeWorkerData & wd, MutableColumns & added_columns) { helper.flushInsertBatch(wd, added_columns); - helper.fillNullMapWithZero(added_columns); + helper.fillNullMapWithZero(added_columns); if constexpr (!has_other_condition) { @@ -436,12 +418,120 @@ struct JoinProbeAdder static void flush(JoinProbeHelper &, JoinProbeWorkerData &, MutableColumns &) {} }; +template +struct JoinProbeAdder +{ + static constexpr bool need_matched = true; + static constexpr bool need_not_matched = false; + static constexpr bool break_on_first_match = false; + + static bool ALWAYS_INLINE addMatched( + JoinProbeHelper & helper, + JoinProbeContext &, + JoinProbeWorkerData & wd, + MutableColumns & added_columns, + size_t idx, + size_t & current_offset, + RowPtr row_ptr, + size_t ptr_offset) + { + wd.right_join_row_ptrs.push_back(hasRowPtrMatchedFlag(row_ptr) ? 
nullptr : row_ptr); + ++current_offset; + wd.selective_offsets.push_back(idx); + helper.insertRowToBatch(wd, added_columns, row_ptr + ptr_offset); + return current_offset >= helper.settings.max_block_size; + } + + static bool ALWAYS_INLINE + addNotMatched(JoinProbeHelper &, JoinProbeContext &, JoinProbeWorkerData &, size_t, size_t &) + { + return false; + } + + static void flush(JoinProbeHelper & helper, JoinProbeWorkerData & wd, MutableColumns & added_columns) + { + helper.flushInsertBatch(wd, added_columns); + helper.fillNullMapWithZero(added_columns); + } +}; + +template + requires(kind == RightSemi || kind == RightAnti) +struct JoinProbeAdder +{ + static constexpr bool need_matched = true; + static constexpr bool need_not_matched = false; + static constexpr bool break_on_first_match = false; + + static bool ALWAYS_INLINE addMatched( + JoinProbeHelper &, + JoinProbeContext &, + JoinProbeWorkerData &, + MutableColumns &, + size_t, + size_t &, + RowPtr row_ptr, + size_t) + { + setRowPtrMatchedFlag(row_ptr); + return false; + } + + static bool ALWAYS_INLINE + addNotMatched(JoinProbeHelper &, JoinProbeContext &, JoinProbeWorkerData &, size_t, size_t &) + { + return false; + } + + static void flush(JoinProbeHelper &, JoinProbeWorkerData &, MutableColumns &) {} +}; + +template + requires(kind == RightSemi || kind == RightAnti) +struct JoinProbeAdder +{ + static constexpr bool need_matched = true; + static constexpr bool need_not_matched = false; + static constexpr bool break_on_first_match = false; + + static bool ALWAYS_INLINE addMatched( + JoinProbeHelper & helper, + JoinProbeContext &, + JoinProbeWorkerData & wd, + MutableColumns & added_columns, + size_t idx, + size_t & current_offset, + RowPtr row_ptr, + size_t ptr_offset) + { + if (hasRowPtrMatchedFlag(row_ptr)) + return false; + ++current_offset; + wd.selective_offsets.push_back(idx); + wd.right_join_row_ptrs.push_back(row_ptr); + helper.insertRowToBatch(wd, added_columns, row_ptr + ptr_offset); + return 
current_offset >= helper.settings.max_block_size; + } + + static bool ALWAYS_INLINE + addNotMatched(JoinProbeHelper &, JoinProbeContext &, JoinProbeWorkerData &, size_t, size_t &) + { + return false; + } + + static void flush(JoinProbeHelper & helper, JoinProbeWorkerData & wd, MutableColumns & added_columns) + { + helper.flushInsertBatch(wd, added_columns); + helper.fillNullMapWithZero(added_columns); + } +}; + JoinProbeHelper::JoinProbeHelper(const HashJoin * join, bool late_materialization) : JoinProbeHelperUtil(join->settings, join->row_layout) , join(join) , pointer_table(join->pointer_table) { -#define CALL3(KeyGetter, JoinType, has_other_condition, late_materialization, tagged_pointer) \ +#define SET_FUNC_PTR(KeyGetter, JoinType, has_other_condition, late_materialization, tagged_pointer) \ { \ func_ptr_has_null \ = &JoinProbeHelper:: \ @@ -451,12 +541,12 @@ JoinProbeHelper::JoinProbeHelper(const HashJoin * join, bool late_materializatio probeImpl; \ } -#define CALL2(KeyGetter, JoinType, has_other_condition, late_materialization) \ - { \ - if (pointer_table.enableTaggedPointer()) \ - CALL3(KeyGetter, JoinType, has_other_condition, late_materialization, true) \ - else \ - CALL3(KeyGetter, JoinType, has_other_condition, late_materialization, false) \ +#define CALL2(KeyGetter, JoinType, has_other_condition, late_materialization) \ + { \ + if (pointer_table.enableTaggedPointer()) \ + SET_FUNC_PTR(KeyGetter, JoinType, has_other_condition, late_materialization, true) \ + else \ + SET_FUNC_PTR(KeyGetter, JoinType, has_other_condition, late_materialization, false) \ } #define CALL1(KeyGetter, JoinType) \ @@ -480,6 +570,8 @@ JoinProbeHelper::JoinProbeHelper(const HashJoin * join, bool late_materializatio CALL1(KeyGetter, Inner) \ else if (kind == LeftOuter) \ CALL1(KeyGetter, LeftOuter) \ + else if (kind == RightOuter) \ + CALL1(KeyGetter, RightOuter) \ else if (kind == Semi && !has_other_condition) \ CALL2(KeyGetter, Semi, false, false) \ else if (kind == Anti && 
!has_other_condition) \ @@ -488,6 +580,14 @@ JoinProbeHelper::JoinProbeHelper(const HashJoin * join, bool late_materializatio CALL2(KeyGetter, LeftOuterSemi, false, false) \ else if (kind == LeftOuterAnti && !has_other_condition) \ CALL2(KeyGetter, LeftOuterAnti, false, false) \ + else if (kind == RightSemi && has_other_condition) \ + CALL2(KeyGetter, RightSemi, true, false) \ + else if (kind == RightSemi && !has_other_condition) \ + CALL2(KeyGetter, RightSemi, false, false) \ + else if (kind == RightAnti && has_other_condition) \ + CALL2(KeyGetter, RightAnti, true, false) \ + else if (kind == RightAnti && !has_other_condition) \ + CALL2(KeyGetter, RightAnti, false, false) \ else \ throw Exception( \ fmt::format("Logical error: unknown combination of JOIN {}", magic_enum::enum_name(kind)), \ @@ -509,10 +609,11 @@ JoinProbeHelper::JoinProbeHelper(const HashJoin * join, bool late_materializatio fmt::format("Unknown JOIN keys variant {}.", magic_enum::enum_name(join->method)), ErrorCodes::UNKNOWN_SET_DATA_VARIANT); } + #undef CALL #undef CALL1 #undef CALL2 -#undef CALL3 +#undef SET_FUNC_PTR } Block JoinProbeHelper::probe(JoinProbeContext & ctx, JoinProbeWorkerData & wd) @@ -556,6 +657,11 @@ Block JoinProbeHelper::probeImpl(JoinProbeContext & ctx, JoinProbeWorkerData & w wd.row_ptrs_for_lm.clear(); wd.row_ptrs_for_lm.reserve(settings.max_block_size); } + if constexpr (kind == RightSemi || kind == RightAnti) + { + wd.right_join_row_ptrs.clear(); + wd.right_join_row_ptrs.reserve(settings.max_block_size); + } size_t left_columns = join->left_sample_block_pruned.columns(); size_t right_columns = join->right_sample_block_pruned.columns(); @@ -755,13 +861,13 @@ void JoinProbeHelper::probeFillColumns(JoinProbeContext & ctx, JoinProbeWorkerDa } } - ptr = getNextRowPtr(ptr); + ptr = getNextRowPtr(ptr); if (ptr == nullptr) break; } if unlikely (ptr != nullptr) { - ptr = getNextRowPtr(ptr); + ptr = getNextRowPtr(ptr); if (ptr == nullptr) ++idx; break; @@ -864,7 +970,7 @@ void 
JoinProbeHelper::probeFillColumnsPrefetch( if (state->stage == ProbePrefetchStage::FindNext) { RowPtr ptr = state->ptr; - RowPtr next_ptr = getNextRowPtr(ptr); + RowPtr next_ptr = getNextRowPtr(ptr); state->ptr = next_ptr; const auto & key2 = key_getter.deserializeJoinKey(ptr + key_offset); @@ -1085,6 +1191,29 @@ Block JoinProbeHelper::handleOtherConditions( ctx.rows_not_matched[idx] &= !is_matched; } } + else if (kind == RightOuter) + { + RUNTIME_CHECK(wd.right_join_row_ptrs.size() == rows); + RUNTIME_CHECK(wd.filter.size() == rows); + for (size_t i = 0; i < rows; ++i) + { + bool is_matched = wd.filter[i]; + if (is_matched && wd.right_join_row_ptrs[i]) + setRowPtrMatchedFlag(wd.right_join_row_ptrs[i]); + } + } + else if (isRightSemiFamily(kind)) + { + RUNTIME_CHECK(wd.right_join_row_ptrs.size() == rows); + RUNTIME_CHECK(wd.filter.size() == rows); + for (size_t i = 0; i < rows; ++i) + { + bool is_matched = wd.filter[i]; + if (is_matched) + setRowPtrMatchedFlag(wd.right_join_row_ptrs[i]); + } + return output_block_after_finalize; + } join->initOutputBlock(wd.result_block_for_other_condition); diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h index 698f1dba409..0f0de74539e 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h @@ -100,12 +100,23 @@ struct alignas(CPU_CACHE_LINE_SIZE) JoinProbeWorkerData /// For late materialization RowPtrs row_ptrs_for_lm; RowPtrs filter_row_ptrs_for_lm; + /// For right outer/semi/anti join with other conditions + RowPtrs right_join_row_ptrs; /// Schema: HashJoin::all_sample_block_pruned Block result_block; /// Schema: HashJoin::output_block_after_finalize Block result_block_for_other_condition; + /// Scan build side + ssize_t current_scan_table_index = -1; + RowContainer * current_container = nullptr; + size_t current_container_index = 0; + /// Schema: HashJoin::output_block_after_finalize + Block scan_result_block; + size_t 
current_scan_block_rows = 0; + bool is_scan_end = false; + /// Metrics size_t probe_handle_rows = 0; size_t probe_time = 0; @@ -113,6 +124,7 @@ struct alignas(CPU_CACHE_LINE_SIZE) JoinProbeWorkerData size_t replicate_time = 0; size_t other_condition_time = 0; size_t collision = 0; + size_t scan_build_side_time = 0; }; class JoinProbeHelperUtil @@ -151,7 +163,6 @@ class JoinProbeHelperUtil template void flushInsertBatch(JoinProbeWorkerData & wd, MutableColumns & added_columns) const; - template void fillNullMapWithZero(MutableColumns & added_columns) const; protected: diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp new file mode 100644 index 00000000000..743f83d0ec9 --- /dev/null +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp @@ -0,0 +1,331 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include + +namespace DB +{ + +using enum ASTTableJoin::Kind; + +JoinProbeBuildScanner::JoinProbeBuildScanner(const HashJoin * join) + : join(join) +{ + join_key_getter = createHashJoinKeyGetter(join->method, join->collators); + ColumnRawPtrs key_columns = extractAndMaterializeKeyColumns( + join->right_sample_block_pruned, + materialized_key_columns, + join->key_names_right); + ColumnPtr null_map_holder; + ConstNullMapPtr null_map{}; + extractNestedColumnsAndNullMap(key_columns, null_map_holder, null_map); + resetHashJoinKeyGetter(join->method, join_key_getter, key_columns, join->row_layout); + + size_t left_columns = join->left_sample_block_pruned.columns(); + + auto kind = join->kind; + bool need_row_data; + bool need_other_block_data; + if (kind == RightOuter) + { + need_row_data = true; + need_other_block_data = false; + } + else + { + need_row_data = false; + for (auto [column_index, is_nullable] : join->row_layout.raw_key_column_indexes) + { + auto output_index = join->output_column_indexes.at(left_columns + column_index); + need_row_data |= output_index >= 0; + if (need_row_data) + break; + } + for (auto [column_index, _] : join->row_layout.other_column_indexes) + { + auto output_index = join->output_column_indexes.at(left_columns + column_index); + need_row_data |= output_index >= 0; + if (need_row_data) + break; + } + + need_other_block_data = (kind == RightSemi || kind == RightAnti) + && join->row_layout.other_column_indexes.size() > join->row_layout.other_column_count_for_other_condition; + + // The output data should not be empty + RUNTIME_CHECK(need_row_data || need_other_block_data); + } + +#define SET_FUNC_PTR(KeyGetter, JoinType, need_row_data, need_other_block_data) \ + { \ + scan_func_ptr = &JoinProbeBuildScanner::scanImpl; \ + } + +#define CALL2(KeyGetter, JoinType) \ + { \ + if (need_row_data && need_other_block_data) \ + SET_FUNC_PTR(KeyGetter, JoinType, true, true) \ + else if (need_row_data) \ + 
SET_FUNC_PTR(KeyGetter, JoinType, true, false) \ + else \ + SET_FUNC_PTR(KeyGetter, JoinType, false, true) \ + } + +#define CALL(KeyGetter) \ + { \ + if (kind == RightOuter) \ + SET_FUNC_PTR(KeyGetter, RightOuter, true, false) \ + else if (kind == RightSemi) \ + CALL2(KeyGetter, RightSemi) \ + else if (kind == RightAnti) \ + CALL2(KeyGetter, RightAnti) \ + else \ + throw Exception( \ + fmt::format( \ + "Logical error: unknown combination of JOIN {} during scanning build side", \ + magic_enum::enum_name(kind)), \ + ErrorCodes::LOGICAL_ERROR); \ + } + + switch (join->method) + { +#define M(METHOD) \ + case HashJoinKeyMethod::METHOD: \ + using KeyGetterType##METHOD = HashJoinKeyGetterForType; \ + CALL(KeyGetterType##METHOD); \ + break; + APPLY_FOR_HASH_JOIN_VARIANTS(M) +#undef M + + default: + throw Exception( + fmt::format("Unknown JOIN keys variant {} during scanning build side", magic_enum::enum_name(join->method)), + ErrorCodes::UNKNOWN_SET_DATA_VARIANT); + } + +#undef CALL +#undef CALL2 +#undef SET_FUNC_PTR +} + +Block JoinProbeBuildScanner::scan(JoinProbeWorkerData & wd) +{ + return (this->*scan_func_ptr)(wd); +} + +template +Block JoinProbeBuildScanner::scanImpl(JoinProbeWorkerData & wd) +{ + if (wd.is_scan_end) + return {}; + + using KeyGetterType = typename KeyGetter::Type; + using HashValueType = typename KeyGetter::HashValueType; + const auto & multi_row_containers = join->multi_row_containers; + const size_t max_block_size = join->settings.max_block_size; + const size_t left_columns = join->left_sample_block_pruned.columns(); + + auto & key_getter = *static_cast(join_key_getter.get()); + constexpr size_t key_offset + = sizeof(RowPtr) + (KeyGetterType::joinKeyCompareHashFirst() ? 
sizeof(HashValueType) : 0); + + size_t scan_size = 0; + RowContainer * container = wd.current_container; + size_t index = wd.current_container_index; + size_t scan_block_rows = wd.current_scan_block_rows; + wd.selective_offsets.clear(); + wd.selective_offsets.reserve(max_block_size); + constexpr size_t insert_batch_max_size = 256; + wd.insert_batch.clear(); + wd.insert_batch.reserve(insert_batch_max_size); + join->initOutputBlock(wd.scan_result_block); + do + { + if (container == nullptr) + { + if (wd.current_scan_table_index != -1) + container = multi_row_containers[wd.current_scan_table_index]->getScanNext(); + if (container == nullptr) + { + { + std::unique_lock lock(scan_build_lock); + for (size_t i = 0; i < JOIN_BUILD_PARTITION_COUNT; ++i) + { + scan_build_index = (scan_build_index + i) % JOIN_BUILD_PARTITION_COUNT; + container = multi_row_containers[scan_build_index]->getScanNext(); + if (container != nullptr) + { + wd.current_scan_table_index = scan_build_index; + scan_build_index = (scan_build_index + 1) % JOIN_BUILD_PARTITION_COUNT; + index = 0; + break; + } + } + } + } + if unlikely (container == nullptr) + { + wd.is_scan_end = true; + break; + } + } + size_t rows = container->size(); + while (index < rows) + { + RowPtr ptr = container->getRowPtr(index); + bool need_output; + if constexpr (kind == RightSemi) + need_output = hasRowPtrMatchedFlag(ptr); + else + need_output = !hasRowPtrMatchedFlag(ptr); + if (need_output) + { + if constexpr (need_row_data) + { + const auto & key = key_getter.deserializeJoinKey(ptr + key_offset); + size_t required_offset = key_offset + key_getter.getRequiredKeyOffset(key); + wd.insert_batch.push_back(ptr + required_offset); + if unlikely (wd.insert_batch.size() >= insert_batch_max_size) + flushInsertBatch(wd); + } + if constexpr (need_other_block_data) + { + wd.selective_offsets.push_back(index); + } + ++scan_block_rows; + if unlikely (scan_block_rows >= max_block_size) + { + ++index; + break; + } + } + ++index; + } + + if 
constexpr (need_other_block_data) + { + size_t other_columns = join->row_layout.other_column_indexes.size() + - join->row_layout.other_column_count_for_other_condition; + RUNTIME_CHECK(container->other_column_block.columns() == other_columns); + for (size_t i = 0; i < other_columns; ++i) + { + size_t column_index + = join->row_layout.other_column_indexes[join->row_layout.other_column_count_for_other_condition + i] + .first; + auto output_index = join->output_column_indexes.at(left_columns + column_index); + // These columns must be in the final output schema otherwise they should be pruned + RUNTIME_CHECK(output_index >= 0); + auto & src_column = container->other_column_block.safeGetByPosition(i); + auto & des_column = wd.scan_result_block.safeGetByPosition(output_index); + des_column.column->assumeMutable()->insertSelectiveFrom(*src_column.column, wd.selective_offsets); + } + wd.selective_offsets.clear(); + } + + if (index >= rows) + container = nullptr; + + if unlikely (scan_block_rows >= max_block_size) + break; + } while (scan_size < 2 * max_block_size); + + flushInsertBatch(wd); + fillNullMapWithZero(wd); + + wd.current_container = container; + wd.current_container_index = index; + wd.current_scan_block_rows = scan_block_rows; + + if unlikely (wd.is_scan_end || wd.scan_result_block.rows() >= max_block_size) + { + if (wd.scan_result_block.rows() == 0) + return {}; + + wd.current_scan_block_rows = 0; + Block res_block; + res_block.swap(wd.scan_result_block); + return res_block; + } + return join->output_block_after_finalize; +} + +template +void JoinProbeBuildScanner::flushInsertBatch(JoinProbeWorkerData & wd) const +{ + const size_t left_columns = join->left_sample_block_pruned.columns(); + for (auto [column_index, is_nullable] : join->row_layout.raw_key_column_indexes) + { + auto output_index = join->output_column_indexes.at(left_columns + column_index); + if (output_index < 0) + { + join->right_sample_block_pruned.safeGetByPosition(column_index) + 
.column->deserializeAndAdvancePos(wd.insert_batch); + continue; + } + auto & des_column = wd.scan_result_block.safeGetByPosition(output_index); + IColumn * column = des_column.column->assumeMutable().get(); + if (is_nullable) + column = &static_cast(*column).getNestedColumn(); + column->deserializeAndInsertFromPos(wd.insert_batch, true); + if constexpr (last_flush) + column->flushNTAlignBuffer(); + } + + size_t other_column_count; + if (join->kind == RightOuter) + other_column_count = join->row_layout.other_column_indexes.size(); + else + other_column_count = join->row_layout.other_column_count_for_other_condition; + for (size_t i = 0; i < other_column_count; ++i) + { + size_t column_index = join->row_layout.other_column_indexes[i].first; + auto output_index = join->output_column_indexes.at(left_columns + column_index); + if (output_index < 0) + { + join->right_sample_block_pruned.safeGetByPosition(column_index) + .column->deserializeAndAdvancePos(wd.insert_batch); + continue; + } + auto & des_column = wd.scan_result_block.safeGetByPosition(output_index); + des_column.column->assumeMutable()->deserializeAndInsertFromPos(wd.insert_batch, true); + if constexpr (last_flush) + des_column.column->assumeMutable()->flushNTAlignBuffer(); + } + + wd.insert_batch.clear(); +} + +void JoinProbeBuildScanner::fillNullMapWithZero(JoinProbeWorkerData & wd) const +{ + size_t left_columns = join->left_sample_block_pruned.columns(); + for (auto [column_index, is_nullable] : join->row_layout.raw_key_column_indexes) + { + auto output_index = join->output_column_indexes.at(left_columns + column_index); + if (!is_nullable || output_index < 0) + continue; + + auto des_mut_column = wd.scan_result_block.safeGetByPosition(output_index).column->assumeMutable(); + RUNTIME_CHECK(des_mut_column->isColumnNullable()); + auto & nullable_column = static_cast(*des_mut_column); + size_t data_size = nullable_column.getNestedColumn().size(); + size_t nullmap_size = 
nullable_column.getNullMapColumn().size(); + RUNTIME_CHECK(nullmap_size <= data_size); + nullable_column.getNullMapColumn().getData().resize_fill_zero(data_size); + } +} + +} // namespace DB diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.h b/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.h new file mode 100644 index 00000000000..0a202d44433 --- /dev/null +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.h @@ -0,0 +1,68 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include + +namespace DB +{ + +class HashJoin; +class JoinProbeBuildScanner +{ +public: + explicit JoinProbeBuildScanner(const HashJoin * join); + + Block scan(JoinProbeWorkerData & wd); + +private: + template + Block scanImpl(JoinProbeWorkerData & wd); + + template + void ALWAYS_INLINE + insertRowToBatch(JoinProbeWorkerData & wd, RowPtr row_ptr, size_t index, size_t insert_batch_max_size) const + { + if constexpr (need_row_data) + { + wd.insert_batch.push_back(row_ptr); + if unlikely (wd.insert_batch.size() >= insert_batch_max_size) + flushInsertBatch(wd); + } + if constexpr (need_other_block_data) + { + wd.selective_offsets.push_back(index); + } + } + + template + void flushInsertBatch(JoinProbeWorkerData & wd) const; + + void fillNullMapWithZero(JoinProbeWorkerData & wd) const; + +private: + using FuncType = Block (JoinProbeBuildScanner::*)(JoinProbeWorkerData &); + FuncType scan_func_ptr = nullptr; + + const HashJoin * join; + std::mutex scan_build_lock; + size_t scan_build_index = 0; + /// Used for deserializing join key and getting required key offset + std::unique_ptr> join_key_getter; + Columns materialized_key_columns; +}; + + +} // namespace DB diff --git a/dbms/src/Interpreters/JoinV2/HashJoinRowLayout.h b/dbms/src/Interpreters/JoinV2/HashJoinRowLayout.h index 16d1c640082..b529ecb9e78 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinRowLayout.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinRowLayout.h @@ -15,9 +15,13 @@ #pragma once #include +#include +#include #include #include +#include +#include #include namespace DB @@ -41,12 +45,16 @@ struct HashJoinRowLayout size_t key_column_fixed_size = 0; size_t other_column_fixed_size = 0; + size_t other_column_for_other_condition_fixed_size = 0; }; using RowPtr = char *; using RowPtrs = PaddedPODArray; -constexpr size_t ROW_ALIGN = 4; +static_assert(alignof(std::atomic) == alignof(uintptr_t)); +constexpr UInt8 ROW_ALIGN = alignof(uintptr_t); +static_assert((ROW_ALIGN & (ROW_ALIGN - 1)) == 0); 
+static_assert(ROW_ALIGN >= 4 && ROW_ALIGN < UINT8_MAX); constexpr size_t ROW_PTR_TAG_BITS = 16; constexpr size_t ROW_PTR_TAG_MASK = (1 << ROW_PTR_TAG_BITS) - 1; @@ -54,9 +62,46 @@ constexpr size_t ROW_PTR_TAG_SHIFT = 8 * sizeof(RowPtr) - ROW_PTR_TAG_BITS; static_assert(sizeof(RowPtr) == sizeof(uintptr_t)); static_assert(sizeof(RowPtr) == 8); +template inline RowPtr getNextRowPtr(RowPtr ptr) { - return unalignedLoad(ptr); + using enum ASTTableJoin::Kind; + if constexpr (kind == RightOuter || kind == RightSemi || kind == RightAnti) + { + auto next = reinterpret_cast *>(ptr)->load(std::memory_order_relaxed); + return reinterpret_cast(next & (~static_cast(ROW_ALIGN - 1))); + } + return *reinterpret_cast(ptr); +} + +inline UInt8 getRowPtrFlag(RowPtr ptr) +{ + return reinterpret_cast *>(ptr)->load(std::memory_order_relaxed) + & static_cast(ROW_ALIGN - 1); +} + +inline bool hasRowPtrMatchedFlag(RowPtr ptr) +{ + return getRowPtrFlag(ptr) & 0x01; +} + +inline void setRowPtrMatchedFlag(RowPtr ptr) +{ + if (hasRowPtrMatchedFlag(ptr)) + return; + reinterpret_cast *>(ptr)->fetch_or(0x01, std::memory_order_relaxed); +} + +inline bool hasRowPtrNullFlag(RowPtr ptr) +{ + return getRowPtrFlag(ptr) & 0x10; +} + +inline void setRowPtrNullFlag(RowPtr ptr) +{ + if (hasRowPtrNullFlag(ptr)) + return; + reinterpret_cast *>(ptr)->fetch_or(0x10, std::memory_order_relaxed); } inline UInt16 getRowPtrTag(RowPtr ptr) @@ -88,6 +133,8 @@ struct RowContainer PaddedPODArray data; PaddedPODArray offsets; PaddedPODArray hashes; + /// Used for right semi/anti join + Block other_column_block; size_t size() const { return offsets.size(); } diff --git a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp index 02f9cb8f797..03687c964f3 100644 --- a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp @@ -269,7 +269,7 @@ SemiJoinProbeHelper::probeFillColumns(JoinProbeContext & ctx, JoinProbeWorkerDat SCOPE_EXIT({ 
flushInsertBatch(wd, added_columns); - fillNullMapWithZero(added_columns); + fillNullMapWithZero(added_columns); wd.collision += collision; }); @@ -294,12 +294,12 @@ SemiJoinProbeHelper::probeFillColumns(JoinProbeContext & ctx, JoinProbeWorkerDat insertRowToBatch(wd, added_columns, ptr + key_offset + key_getter.getRequiredKeyOffset(key2)); if unlikely (current_offset >= end_offset) { - ptr = getNextRowPtr(ptr); + ptr = getNextRowPtr(ptr); break; } } - ptr = getNextRowPtr(ptr); + ptr = getNextRowPtr(ptr); } if (prev_offset == current_offset) { @@ -365,12 +365,12 @@ SemiJoinProbeHelper::probeFillColumns(JoinProbeContext & ctx, JoinProbeWorkerDat insertRowToBatch(wd, added_columns, ptr + key_offset + key_getter.getRequiredKeyOffset(key2)); if unlikely (current_offset >= end_offset) { - ptr = getNextRowPtr(ptr); + ptr = getNextRowPtr(ptr); break; } } - ptr = getNextRowPtr(ptr); + ptr = getNextRowPtr(ptr); if (ptr == nullptr) break; } @@ -474,7 +474,7 @@ void NO_INLINE SemiJoinProbeHelper::probeFillColumnsPrefetch( if (state->stage == ProbePrefetchStage::FindNext) { RowPtr ptr = state->ptr; - RowPtr next_ptr = getNextRowPtr(ptr); + RowPtr next_ptr = getNextRowPtr(ptr); state->ptr = next_ptr; const auto & key2 = key_getter.deserializeJoinKey(ptr + key_offset); @@ -650,7 +650,7 @@ void NO_INLINE SemiJoinProbeHelper::probeFillColumnsPrefetch( } flushInsertBatch(wd, added_columns); - fillNullMapWithZero(added_columns); + fillNullMapWithZero(added_columns); ctx.current_row_idx = idx; ctx.prefetch_active_states = active_states; diff --git a/dbms/src/Operators/HashJoinV2ProbeTransformOp.cpp b/dbms/src/Operators/HashJoinV2ProbeTransformOp.cpp index efd8ce06bbc..e3a0850a309 100644 --- a/dbms/src/Operators/HashJoinV2ProbeTransformOp.cpp +++ b/dbms/src/Operators/HashJoinV2ProbeTransformOp.cpp @@ -50,37 +50,59 @@ void HashJoinV2ProbeTransformOp::operateSuffixImpl() OperatorStatus HashJoinV2ProbeTransformOp::onOutput(Block & block) { - assert(!probe_context.isAllFinished()); - block 
= join_ptr->probeBlock(probe_context, op_index); - size_t rows = block.rows(); - joined_rows += rows; - return OperatorStatus::HAS_OUTPUT; + while (true) + { + switch (status) + { + case ProbeStatus::PROBE: + if unlikely (probe_context.isAllFinished()) + { + join_ptr->finishOneProbe(op_index); + status = join_ptr->needProbeScanBuildSide() ? ProbeStatus::WAIT_PROBE_FINISH : ProbeStatus::FINISHED; + block = join_ptr->probeLastResultBlock(op_index); + if (block) + return OperatorStatus::HAS_OUTPUT; + break; + } + block = join_ptr->probeBlock(probe_context, op_index); + joined_rows += block.rows(); + return OperatorStatus::HAS_OUTPUT; + case ProbeStatus::WAIT_PROBE_FINISH: + if (join_ptr->isAllProbeFinished()) + { + status = ProbeStatus::SCAN_BUILD_SIDE; + break; + } + return OperatorStatus::WAIT_FOR_NOTIFY; + case ProbeStatus::SCAN_BUILD_SIDE: + block = join_ptr->probeScanBuildSide(op_index); + scan_hash_map_rows += block.rows(); + if unlikely (!block) + status = ProbeStatus::FINISHED; + return OperatorStatus::HAS_OUTPUT; + case ProbeStatus::FINISHED: + block = {}; + return OperatorStatus::HAS_OUTPUT; + } + } } OperatorStatus HashJoinV2ProbeTransformOp::transformImpl(Block & block) { + assert(status == ProbeStatus::PROBE); assert(probe_context.isAllFinished()); - if unlikely (!block) + if likely (block) { - join_ptr->finishOneProbe(op_index); - probe_context.input_is_finished = true; - block = join_ptr->probeLastResultBlock(op_index); - return OperatorStatus::HAS_OUTPUT; + if (block.rows() == 0) + return OperatorStatus::NEED_INPUT; + probe_context.resetBlock(block); } - if (block.rows() == 0) - return OperatorStatus::NEED_INPUT; - probe_context.resetBlock(block); return onOutput(block); } OperatorStatus HashJoinV2ProbeTransformOp::tryOutputImpl(Block & block) { - if unlikely (probe_context.input_is_finished) - { - block = {}; - return OperatorStatus::HAS_OUTPUT; - } - if (probe_context.isAllFinished()) + if (status == ProbeStatus::PROBE && 
probe_context.isAllFinished()) return OperatorStatus::NEED_INPUT; return onOutput(block); } diff --git a/dbms/src/Operators/HashJoinV2ProbeTransformOp.h b/dbms/src/Operators/HashJoinV2ProbeTransformOp.h index b78872a6ae6..6e19e6e0fc1 100644 --- a/dbms/src/Operators/HashJoinV2ProbeTransformOp.h +++ b/dbms/src/Operators/HashJoinV2ProbeTransformOp.h @@ -51,5 +51,14 @@ class HashJoinV2ProbeTransformOp : public TransformOp size_t joined_rows = 0; size_t scan_hash_map_rows = 0; + + enum class ProbeStatus + { + PROBE, + WAIT_PROBE_FINISH, + SCAN_BUILD_SIDE, + FINISHED, + }; + ProbeStatus status = ProbeStatus::PROBE; }; } // namespace DB From d44c1867a08ed87f6390a359032f152d23600bf1 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Fri, 6 Jun 2025 11:57:53 +0800 Subject: [PATCH 68/84] u Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp index 743f83d0ec9..a4ca34ed8b0 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp @@ -171,7 +171,6 @@ Block JoinProbeBuildScanner::scanImpl(JoinProbeWorkerData & wd) { wd.current_scan_table_index = scan_build_index; scan_build_index = (scan_build_index + 1) % JOIN_BUILD_PARTITION_COUNT; - index = 0; break; } } @@ -237,7 +236,10 @@ Block JoinProbeBuildScanner::scanImpl(JoinProbeWorkerData & wd) } if (index >= rows) + { container = nullptr; + index = 0; + } if unlikely (scan_block_rows >= max_block_size) break; From f4b87a85f5223a77dbb41dbcfafc27f0006cbb70 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Fri, 6 Jun 2025 15:48:57 +0800 Subject: [PATCH 69/84] u Signed-off-by: gengliqi --- .../Planner/Plans/PhysicalJoinV2Probe.cpp | 1 + dbms/src/Interpreters/JoinV2/HashJoin.h | 2 ++ .../src/Interpreters/JoinV2/HashJoinProbe.cpp | 22 
++++++++-------- dbms/src/Interpreters/JoinV2/HashJoinProbe.h | 6 ++--- .../src/Interpreters/JoinV2/SemiJoinProbe.cpp | 26 ++++++++++++------- 5 files changed, 33 insertions(+), 24 deletions(-) diff --git a/dbms/src/Flash/Planner/Plans/PhysicalJoinV2Probe.cpp b/dbms/src/Flash/Planner/Plans/PhysicalJoinV2Probe.cpp index 73292e8a27b..2dcb2f93e83 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalJoinV2Probe.cpp +++ b/dbms/src/Flash/Planner/Plans/PhysicalJoinV2Probe.cpp @@ -36,6 +36,7 @@ void PhysicalJoinV2Probe::buildPipelineExecGroupImpl( builder.appendTransformOp( std::make_unique(exec_context, log->identifier(), join_ptr, probe_index++)); }); + exec_context.addOneTimeFuture(join_ptr->getWaitProbeFinishFuture()); join_ptr.reset(); } } // namespace DB diff --git a/dbms/src/Interpreters/JoinV2/HashJoin.h b/dbms/src/Interpreters/JoinV2/HashJoin.h index a23967fbace..b9610c1dc75 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoin.h +++ b/dbms/src/Interpreters/JoinV2/HashJoin.h @@ -84,6 +84,8 @@ class HashJoin const JoinProfileInfoPtr & getProfileInfo() const { return profile_info; } + const OneTimeNotifyFuturePtr & getWaitProbeFinishFuture() { return wait_probe_finished_future; } + private: void initRowLayoutAndHashJoinMethod(); diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index 45c6de18d19..ecd95d3f89c 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -149,7 +149,7 @@ void JoinProbeContext::prepareForHashProbe( is_prepared = true; } -template +template void JoinProbeHelperUtil::flushInsertBatch(JoinProbeWorkerData & wd, MutableColumns & added_columns) const { for (auto [column_index, is_nullable] : row_layout.raw_key_column_indexes) @@ -163,7 +163,7 @@ void JoinProbeHelperUtil::flushInsertBatch(JoinProbeWorkerData & wd, MutableColu } size_t add_size; - if constexpr (late_materialization) + if constexpr (late_materialization || 
is_right_semi_join) add_size = row_layout.other_column_count_for_other_condition; else add_size = row_layout.other_column_indexes.size(); @@ -174,7 +174,7 @@ void JoinProbeHelperUtil::flushInsertBatch(JoinProbeWorkerData & wd, MutableColu if constexpr (last_flush) added_columns[column_index]->flushNTAlignBuffer(); } - if constexpr (late_materialization) + if constexpr (late_materialization && !is_right_semi_join) wd.row_ptrs_for_lm.insert(wd.insert_batch.begin(), wd.insert_batch.end()); wd.insert_batch.clear(); @@ -215,7 +215,7 @@ struct JoinProbeAdder { ++current_offset; wd.selective_offsets.push_back(idx); - helper.insertRowToBatch(wd, added_columns, row_ptr + ptr_offset); + helper.insertRowToBatch(wd, added_columns, row_ptr + ptr_offset); return current_offset >= helper.settings.max_block_size; } @@ -227,7 +227,7 @@ struct JoinProbeAdder static void flush(JoinProbeHelper & helper, JoinProbeWorkerData & wd, MutableColumns & added_columns) { - helper.flushInsertBatch(wd, added_columns); + helper.flushInsertBatch(wd, added_columns); helper.fillNullMapWithZero(added_columns); } }; @@ -251,7 +251,7 @@ struct JoinProbeAdder { ++current_offset; wd.selective_offsets.push_back(idx); - helper.insertRowToBatch(wd, added_columns, row_ptr + ptr_offset); + helper.insertRowToBatch(wd, added_columns, row_ptr + ptr_offset); return current_offset >= helper.settings.max_block_size; } @@ -273,7 +273,7 @@ struct JoinProbeAdder static void flush(JoinProbeHelper & helper, JoinProbeWorkerData & wd, MutableColumns & added_columns) { - helper.flushInsertBatch(wd, added_columns); + helper.flushInsertBatch(wd, added_columns); helper.fillNullMapWithZero(added_columns); if constexpr (!has_other_condition) @@ -438,7 +438,7 @@ struct JoinProbeAdder wd.right_join_row_ptrs.push_back(hasRowPtrMatchedFlag(row_ptr) ? 
nullptr : row_ptr); ++current_offset; wd.selective_offsets.push_back(idx); - helper.insertRowToBatch(wd, added_columns, row_ptr + ptr_offset); + helper.insertRowToBatch(wd, added_columns, row_ptr + ptr_offset); return current_offset >= helper.settings.max_block_size; } @@ -450,7 +450,7 @@ struct JoinProbeAdder static void flush(JoinProbeHelper & helper, JoinProbeWorkerData & wd, MutableColumns & added_columns) { - helper.flushInsertBatch(wd, added_columns); + helper.flushInsertBatch(wd, added_columns); helper.fillNullMapWithZero(added_columns); } }; @@ -509,7 +509,7 @@ struct JoinProbeAdder ++current_offset; wd.selective_offsets.push_back(idx); wd.right_join_row_ptrs.push_back(row_ptr); - helper.insertRowToBatch(wd, added_columns, row_ptr + ptr_offset); + helper.insertRowToBatch(wd, added_columns, row_ptr + ptr_offset); return current_offset >= helper.settings.max_block_size; } @@ -521,7 +521,7 @@ struct JoinProbeAdder static void flush(JoinProbeHelper & helper, JoinProbeWorkerData & wd, MutableColumns & added_columns) { - helper.flushInsertBatch(wd, added_columns); + helper.flushInsertBatch(wd, added_columns); helper.fillNullMapWithZero(added_columns); } }; diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h index 0f0de74539e..f3f4df130b2 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h @@ -152,15 +152,15 @@ class JoinProbeHelperUtil return key_getter.joinKeyIsEqual(key1, key2); } - template + template void ALWAYS_INLINE insertRowToBatch(JoinProbeWorkerData & wd, MutableColumns & added_columns, RowPtr row_ptr) const { wd.insert_batch.push_back(row_ptr); if unlikely (wd.insert_batch.size() >= settings.probe_insert_batch_size) - flushInsertBatch(wd, added_columns); + flushInsertBatch(wd, added_columns); } - template + template void flushInsertBatch(JoinProbeWorkerData & wd, MutableColumns & added_columns) const; void fillNullMapWithZero(MutableColumns & 
added_columns) const; diff --git a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp index 03687c964f3..c5c6b416790 100644 --- a/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/SemiJoinProbe.cpp @@ -261,14 +261,11 @@ SemiJoinProbeHelper::probeFillColumns(JoinProbeContext & ctx, JoinProbeWorkerDat RUNTIME_CHECK(probe_list->slotCapacity() == ctx.rows); size_t current_offset = 0; size_t collision = 0; - size_t key_offset = sizeof(RowPtr); - if constexpr (KeyGetterType::joinKeyCompareHashFirst()) - { - key_offset += sizeof(HashValueType); - } + constexpr size_t key_offset + = sizeof(RowPtr) + (KeyGetterType::joinKeyCompareHashFirst() ? sizeof(HashValueType) : 0); SCOPE_EXIT({ - flushInsertBatch(wd, added_columns); + flushInsertBatch(wd, added_columns); fillNullMapWithZero(added_columns); wd.collision += collision; @@ -291,7 +288,10 @@ SemiJoinProbeHelper::probeFillColumns(JoinProbeContext & ctx, JoinProbeWorkerDat if (key_is_equal) { wd.selective_offsets.push_back(idx); - insertRowToBatch(wd, added_columns, ptr + key_offset + key_getter.getRequiredKeyOffset(key2)); + insertRowToBatch( + wd, + added_columns, + ptr + key_offset + key_getter.getRequiredKeyOffset(key2)); if unlikely (current_offset >= end_offset) { ptr = getNextRowPtr(ptr); @@ -362,7 +362,10 @@ SemiJoinProbeHelper::probeFillColumns(JoinProbeContext & ctx, JoinProbeWorkerDat if (key_is_equal) { wd.selective_offsets.push_back(idx); - insertRowToBatch(wd, added_columns, ptr + key_offset + key_getter.getRequiredKeyOffset(key2)); + insertRowToBatch( + wd, + added_columns, + ptr + key_offset + key_getter.getRequiredKeyOffset(key2)); if unlikely (current_offset >= end_offset) { ptr = getNextRowPtr(ptr); @@ -487,7 +490,10 @@ void NO_INLINE SemiJoinProbeHelper::probeFillColumnsPrefetch( if (key_is_equal) { wd.selective_offsets.push_back(state->index); - insertRowToBatch(wd, added_columns, ptr + key_offset + 
key_getter.getRequiredKeyOffset(key2)); + insertRowToBatch( + wd, + added_columns, + ptr + key_offset + key_getter.getRequiredKeyOffset(key2)); if unlikely (current_offset >= settings.max_block_size) { probe_list->at(state->index).build_row_ptr = next_ptr; @@ -649,7 +655,7 @@ void NO_INLINE SemiJoinProbeHelper::probeFillColumnsPrefetch( } } - flushInsertBatch(wd, added_columns); + flushInsertBatch(wd, added_columns); fillNullMapWithZero(added_columns); ctx.current_row_idx = idx; From 12dfd490ecaf3c5da65e74b2702388e5279d4331 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Fri, 6 Jun 2025 16:24:02 +0800 Subject: [PATCH 70/84] u Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/HashJoinProbe.h | 2 +- dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h index f3f4df130b2..16f881af801 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h @@ -157,7 +157,7 @@ class JoinProbeHelperUtil { wd.insert_batch.push_back(row_ptr); if unlikely (wd.insert_batch.size() >= settings.probe_insert_batch_size) - flushInsertBatch(wd, added_columns); + flushInsertBatch(wd, added_columns); } template diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp index a4ca34ed8b0..bd0db517147 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp @@ -26,10 +26,8 @@ JoinProbeBuildScanner::JoinProbeBuildScanner(const HashJoin * join) : join(join) { join_key_getter = createHashJoinKeyGetter(join->method, join->collators); - ColumnRawPtrs key_columns = extractAndMaterializeKeyColumns( - join->right_sample_block_pruned, - materialized_key_columns, - join->key_names_right); + ColumnRawPtrs key_columns + = 
extractAndMaterializeKeyColumns(join->right_sample_block, materialized_key_columns, join->key_names_right); ColumnPtr null_map_holder; ConstNullMapPtr null_map{}; extractNestedColumnsAndNullMap(key_columns, null_map_holder, null_map); From 81eeb908378b596b3856099eda7c1572c2196635 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Fri, 6 Jun 2025 20:19:25 +0800 Subject: [PATCH 71/84] u Signed-off-by: gengliqi --- dbms/CMakeLists.txt | 1 + dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/dbms/CMakeLists.txt b/dbms/CMakeLists.txt index af756df9205..8b432defc38 100644 --- a/dbms/CMakeLists.txt +++ b/dbms/CMakeLists.txt @@ -114,6 +114,7 @@ check_then_add_sources_compile_flag ( src/Interpreters/JoinV2/HashJoinBuild.cpp src/Interpreters/JoinV2/HashJoinProbe.cpp src/Interpreters/JoinV2/SemiJoinProbe.cpp + src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp src/IO/Compression/EncodingUtil.cpp src/Storages/DeltaMerge/BitmapFilter/BitmapFilter.cpp src/Storages/DeltaMerge/DMVersionFilterBlockInputStream.cpp diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index ecd95d3f89c..3f800db8174 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -1516,4 +1516,9 @@ Block JoinProbeHelper::genResultBlockForLeftOuterSemi(JoinProbeContext & ctx) return res_block; } +// SemiJoinProbe.cpp calls this function +template void DB::JoinProbeHelperUtil::flushInsertBatch( + DB::JoinProbeWorkerData & wd, + DB::MutableColumns & added_columns) const; + } // namespace DB From af9c1e0abf68e034870497e8f8aac0d7a92eabc1 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Sat, 7 Jun 2025 15:12:03 +0800 Subject: [PATCH 72/84] u Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp 
b/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp index bd0db517147..d167522b7b0 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp @@ -233,6 +233,8 @@ Block JoinProbeBuildScanner::scanImpl(JoinProbeWorkerData & wd) wd.selective_offsets.clear(); } + scan_size += rows - index; + if (index >= rows) { container = nullptr; From f960707aba5c86ef5ae88e43bc298ff935a8d3c7 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Sat, 7 Jun 2025 16:04:34 +0800 Subject: [PATCH 73/84] fix Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/HashJoin.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoin.cpp b/dbms/src/Interpreters/JoinV2/HashJoin.cpp index 9c74a1d9de0..9b3e140b9b7 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoin.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoin.cpp @@ -681,7 +681,7 @@ Block HashJoin::probeLastResultBlock(size_t stream_index) bool HashJoin::needProbeScanBuildSide() const { - return join_probe_build_scanner != nullptr; + return isRightOuterJoin(kind) || isRightSemiFamily(kind); } Block HashJoin::probeScanBuildSide(size_t stream_index) From 750a788a8e2bce07890ff375df9d8880b2a2d6e9 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Tue, 10 Jun 2025 00:17:52 +0800 Subject: [PATCH 74/84] consider a key is required only if it's needed for other condition in right semi/anti join Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/HashJoin.cpp | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoin.cpp b/dbms/src/Interpreters/JoinV2/HashJoin.cpp index 9b3e140b9b7..8418d6d32ea 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoin.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoin.cpp @@ -199,19 +199,27 @@ void HashJoin::initRowLayoutAndHashJoinMethod() /// Move all raw required join key column to the end of the join key. 
Names new_key_names_left, new_key_names_right; BoolVec raw_required_key_flag(keys_size); + bool is_right_semi_join = isRightSemiFamily(kind); + auto check_key_required = [&](const String & name) -> bool { + // If it's right semi/anti join, the key is required if and only if it's needed for other condition + if (is_right_semi_join) + return required_columns_names_set_for_other_condition.contains(name); + else + return right_sample_block_pruned.has(name); + }; for (size_t i = 0; i < keys_size; ++i) { bool is_raw_required = false; if (key_columns[i].column_ptr->valuesHaveFixedSize()) { - if (right_sample_block_pruned.has(key_names_right[i])) + if (check_key_required(key_names_right[i])) is_raw_required = true; } else { if (canAsColumnString(key_columns[i].column_ptr) && getStringCollatorKind(collators) == StringCollatorKind::StringBinary - && right_sample_block_pruned.has(key_names_right[i])) + && check_key_required(key_names_right[i])) { is_raw_required = true; } @@ -417,6 +425,9 @@ void HashJoin::initProbe(const Block & sample_block, size_t probe_concurrency_) active_probe_worker = probe_concurrency; probe_workers_data.resize(probe_concurrency); + if (needProbeScanBuildSide()) + join_probe_build_scanner = std::make_unique(this); + probe_initialized = true; } @@ -457,8 +468,6 @@ bool HashJoin::finishOneProbe(size_t stream_index) if (active_probe_worker.fetch_sub(1) == 1) { FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::exception_mpp_hash_probe); - if (isRightOuterJoin(kind) || isRightSemiFamily(kind)) - join_probe_build_scanner = std::make_unique(this); wait_probe_finished_future->finish(); return true; } From 36401b78f9cb5eb1c4fd8d399ca94adb7b9be3f9 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Fri, 13 Jun 2025 04:23:15 +0800 Subject: [PATCH 75/84] u Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/HashJoin.cpp | 41 +++- dbms/src/Interpreters/JoinV2/HashJoin.h | 3 + .../src/Interpreters/JoinV2/HashJoinBuild.cpp | 227 ++++++++++-------- 
dbms/src/Interpreters/JoinV2/HashJoinBuild.h | 45 ++-- .../JoinV2/HashJoinProbeBuildScanner.cpp | 7 +- .../JoinV2/HashJoinProbeBuildScanner.h | 4 +- .../Interpreters/JoinV2/HashJoinRowLayout.h | 93 ++++++- 7 files changed, 280 insertions(+), 140 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoin.cpp b/dbms/src/Interpreters/JoinV2/HashJoin.cpp index 8418d6d32ea..ee4c55202a5 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoin.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoin.cpp @@ -329,7 +329,7 @@ void HashJoin::initBuild(const Block & sample_block, size_t build_concurrency_) build_workers_data.resize(build_concurrency); for (size_t i = 0; i < build_concurrency; ++i) build_workers_data[i].key_getter = createHashJoinKeyGetter(method, collators); - for (size_t i = 0; i < JOIN_BUILD_PARTITION_COUNT + 1; ++i) + for (size_t i = 0; i < JOIN_BUILD_PARTITION_COUNT; ++i) multi_row_containers.emplace_back(std::make_unique()); build_initialized = true; @@ -352,7 +352,28 @@ void HashJoin::initProbe(const Block & sample_block, size_t probe_concurrency_) } left_sample_block_pruned = left_sample_block; - removeUselessColumn(left_sample_block_pruned); + + const NameSet & probe_output_name_set = has_other_condition + ? 
output_columns_names_set_for_other_condition_after_finalize + : output_column_names_set_after_finalize; + for (size_t pos = 0; pos < left_sample_block_pruned.columns();) + { + if (!probe_output_name_set.contains(left_sample_block_pruned.getByPosition(pos).name)) + { + if (std::find( + key_names_left.begin(), + key_names_left.end(), + left_sample_block_pruned.getByPosition(pos).name) + == key_names_left.end()) + { + LOG_ERROR(log, "shit"); + } + left_sample_block_pruned.erase(pos); + } + else + ++pos; + } + //removeUselessColumn(left_sample_block_pruned); all_sample_block_pruned = left_sample_block_pruned.cloneEmpty(); size_t right_columns = right_sample_block_pruned.columns(); @@ -434,12 +455,19 @@ void HashJoin::initProbe(const Block & sample_block, size_t probe_concurrency_) bool HashJoin::finishOneBuildRow(size_t stream_index) { auto & wd = build_workers_data[stream_index]; + if (wd.non_joined_block.rows() > 0) + { + non_joined_blocks.insertNonFullBlock(std::move(wd.non_joined_block)); + wd.non_joined_block = {}; + } LOG_DEBUG( log, - "{} insert block to row containers cost {}ms, row count {}, padding size {}({:.2f}% of all size {})", + "{} insert block to row containers cost {}ms, row count {}, non-joined count {}, padding size {}({:.2f}% of " + "all size {})", stream_index, wd.build_time, wd.row_count, + wd.non_joined_row_count, wd.padding_size, 100.0 * wd.padding_size / wd.all_size, wd.all_size); @@ -582,15 +610,12 @@ void HashJoin::buildRowFromBlock(const Block & b, size_t stream_index) assertBlocksHaveEqualStructure(block, right_sample_block_pruned, "Join Build"); - insertBlockToRowContainers( - method, - kind, + JoinBuildHelper::insertBlockToRowContainers( + this, block, rows, key_columns, null_map, - row_layout, - multi_row_containers, build_workers_data[stream_index], shouldCheckLateMaterialization()); diff --git a/dbms/src/Interpreters/JoinV2/HashJoin.h b/dbms/src/Interpreters/JoinV2/HashJoin.h index b9610c1dc75..4d82b55ef81 100644 --- 
a/dbms/src/Interpreters/JoinV2/HashJoin.h +++ b/dbms/src/Interpreters/JoinV2/HashJoin.h @@ -99,6 +99,7 @@ class HashJoin } private: + friend JoinBuildHelper; friend JoinProbeHelper; friend SemiJoinProbeHelper; friend JoinProbeBuildScanner; @@ -158,6 +159,8 @@ class HashJoin /// Row containers std::vector> multi_row_containers; + /// Non-joined blocks + NonJoinedBlocks non_joined_blocks; /// Build row phase size_t build_concurrency = 0; diff --git a/dbms/src/Interpreters/JoinV2/HashJoinBuild.cpp b/dbms/src/Interpreters/JoinV2/HashJoinBuild.cpp index 1d72afcad78..0e28ca70a0b 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinBuild.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinBuild.cpp @@ -13,10 +13,9 @@ // limitations under the License. #include +#include #include -#include "Interpreters/JoinUtils.h" -#include "Parsers/ASTTablesInSelectQuery.h" namespace DB { @@ -27,14 +26,12 @@ extern const int UNKNOWN_SET_DATA_VARIANT; template -void NO_INLINE insertBlockToRowContainersTypeImpl( - ASTTableJoin::Kind kind, +void NO_INLINE JoinBuildHelper::insertBlockToRowContainersImpl( + HashJoin * join, Block & block, size_t rows, const ColumnRawPtrs & key_columns, ConstNullMapPtr null_map, - const HashJoinRowLayout & row_layout, - std::vector> & multi_row_containers, JoinBuildWorkerData & wd, bool check_lm_row_size) { @@ -43,29 +40,45 @@ void NO_INLINE insertBlockToRowContainersTypeImpl( using HashValueType = typename KeyGetter::HashValueType; static_assert(sizeof(HashValueType) <= sizeof(decltype(wd.hashes)::value_type)); + const auto kind = join->kind; + const auto & row_layout = join->row_layout; + const auto & settings = join->settings; + auto & multi_row_containers = join->multi_row_containers; + auto & non_joined_blocks = join->non_joined_blocks; + auto & key_getter = *static_cast(wd.key_getter.get()); key_getter.reset(key_columns, row_layout.raw_key_column_indexes.size()); + RUNTIME_CHECK(multi_row_containers.size() == JOIN_BUILD_PARTITION_COUNT); wd.row_sizes.clear(); 
bool is_right_semi_family = isRightSemiFamily(kind); if (is_right_semi_family) { wd.row_sizes.resize_fill(rows, row_layout.other_column_for_other_condition_fixed_size); - wd.right_semi_selector.resize(rows); + wd.right_semi_selector.clear(); + wd.right_semi_selector.reserve(rows); + if constexpr (has_null_map) + { + wd.right_semi_offsets.clear(); + wd.right_semi_offsets.reserve(rows); + } } else { wd.row_sizes.resize_fill(rows, row_layout.other_column_fixed_size); } + if constexpr (has_null_map && need_record_null_rows) + { + wd.non_joined_offsets.clear(); + wd.non_joined_offsets.reserve(rows); + } wd.hashes.resize(rows); - /// The last partition is used to hold rows with null join key. - constexpr size_t part_count = JOIN_BUILD_PARTITION_COUNT + 1; wd.partition_row_sizes.clear(); - wd.partition_row_sizes.resize_fill_zero(part_count); + wd.partition_row_sizes.resize_fill_zero(JOIN_BUILD_PARTITION_COUNT); wd.partition_row_count.clear(); - wd.partition_row_count.resize_fill_zero(part_count); + wd.partition_row_count.resize_fill_zero(JOIN_BUILD_PARTITION_COUNT); wd.partition_last_row_index.clear(); - wd.partition_last_row_index.resize_fill(part_count, -1); + wd.partition_last_row_index.resize_fill(JOIN_BUILD_PARTITION_COUNT, -1); if (check_lm_row_size) { @@ -94,27 +107,23 @@ void NO_INLINE insertBlockToRowContainersTypeImpl( for (size_t i = 0; i < rows; ++i) { - if (has_null_map && (*null_map)[i]) + if constexpr (has_null_map) { - // TODO: the non-key row does not needed for right semi join. 
However, IColumn::scatterTo must need a selector for all rows - if (is_right_semi_family) - wd.right_semi_selector[i] = part_count - 1; - if constexpr (need_record_null_rows) + if ((*null_map)[i]) { - RUNTIME_CHECK_MSG(false, "need_record_null_rows is not supported yet"); - //TODO - //wd.row_sizes[i] += sizeof(RowPtr); - //wd.row_sizes[i] = alignRowSize(wd.row_sizes[i]); - //wd.partition_row_sizes[part_count - 1] += wd.row_sizes[i]; - //++wd.partition_row_count[part_count - 1]; + if constexpr (need_record_null_rows) + wd.non_joined_offsets.push_back(i); + continue; } - continue; + if (is_right_semi_family) + wd.right_semi_offsets.push_back(i); } + const auto & key = key_getter.getJoinKeyWithBuffer(i); wd.hashes[i] = static_cast(Hash()(key)); size_t part_num = getJoinBuildPartitionNum(wd.hashes[i]); if (is_right_semi_family) - wd.right_semi_selector[i] = part_num; + wd.right_semi_selector.push_back(part_num); size_t ptr_and_key_size = sizeof(RowPtr) + key_getter.getJoinKeyByteSize(key); if constexpr (KeyGetterType::joinKeyCompareHashFirst()) @@ -152,8 +161,9 @@ void NO_INLINE insertBlockToRowContainersTypeImpl( ++wd.partition_row_count[part_num]; } - std::array partition_row_container; - for (size_t i = 0; i < part_count; ++i) + std::array partition_row_container; + size_t row_count = 0; + for (size_t i = 0; i < JOIN_BUILD_PARTITION_COUNT; ++i) { if (wd.partition_row_count[i] > 0) { @@ -169,13 +179,12 @@ void NO_INLINE insertBlockToRowContainersTypeImpl( container.hashes.reserve(wd.partition_row_count[i]); wd.partition_row_sizes[i] = 0; - if (i != JOIN_BUILD_PARTITION_COUNT) - { - // Do not add the count of null rows - wd.row_count += wd.partition_row_count[i]; - } + row_count += wd.partition_row_count[i]; } } + RUNTIME_CHECK(row_count <= rows); + wd.row_count += row_count; + wd.non_joined_row_count += rows - row_count; constexpr size_t step = 256; wd.row_ptrs.reserve(step); @@ -189,14 +198,7 @@ void NO_INLINE insertBlockToRowContainersTypeImpl( { if (has_null_map 
&& (*null_map)[j]) { - if constexpr (need_record_null_rows) - { - //TODO - } - else - { - wd.row_ptrs.push_back(nullptr); - } + wd.row_ptrs.push_back(nullptr); continue; } size_t part_num = getJoinBuildPartitionNum(wd.hashes[j]); @@ -226,7 +228,7 @@ void NO_INLINE insertBlockToRowContainersTypeImpl( for (size_t i = 0; i < other_column_count; ++i) { size_t index = row_layout.other_column_indexes[i].first; - if constexpr (has_null_map && !need_record_null_rows) + if constexpr (has_null_map) block.getByPosition(index).column->serializeToPos(wd.row_ptrs, start, end - start, true); else block.getByPosition(index).column->serializeToPos(wd.row_ptrs, start, end - start, false); @@ -235,19 +237,19 @@ void NO_INLINE insertBlockToRowContainersTypeImpl( if (isRightSemiFamily(kind)) { - IColumn::ScatterColumns scatter_columns(part_count); + IColumn::ScatterColumns scatter_columns(JOIN_BUILD_PARTITION_COUNT); for (size_t i = row_layout.other_column_count_for_other_condition; i < row_layout.other_column_indexes.size(); ++i) { size_t index = row_layout.other_column_indexes[i].first; auto & column_data = block.getByPosition(index); size_t column_memory = column_data.column->byteSize(); - for (size_t j = 0; j < part_count; ++j) + for (size_t j = 0; j < JOIN_BUILD_PARTITION_COUNT; ++j) { auto new_column_data = column_data.cloneEmpty(); if (wd.partition_row_count[j] > 0) { - size_t memory_hint = column_memory * wd.partition_row_count[j] / rows + 16; + size_t memory_hint = 1.2 * column_memory * wd.partition_row_count[j] / rows; new_column_data.column->assumeMutable()->reserveWithTotalMemoryHint( wd.partition_row_count[j], memory_hint); @@ -255,96 +257,113 @@ void NO_INLINE insertBlockToRowContainersTypeImpl( scatter_columns[j] = new_column_data.column->assumeMutable(); partition_row_container[j].other_column_block.insert(std::move(new_column_data)); } - column_data.column->scatterTo(scatter_columns, wd.right_semi_selector); + if constexpr (has_null_map) + 
column_data.column->scatterTo(scatter_columns, wd.right_semi_selector, wd.right_semi_offsets); + else + column_data.column->scatterTo(scatter_columns, wd.right_semi_selector); } } - for (size_t i = 0; i < part_count; ++i) + if constexpr (has_null_map && need_record_null_rows) + { + if (!wd.non_joined_offsets.empty()) + { + join->initOutputBlock(wd.non_joined_block); + RUNTIME_CHECK(wd.non_joined_block.rows() < settings.max_block_size); + size_t columns = wd.non_joined_block.columns(); + auto fill_block = [&](size_t offset_start, size_t length) { + for (size_t i = 0; i < columns; ++i) + { + const auto & name = wd.non_joined_block.getByPosition(i).name; + auto & src_column = block.getByName(name).column; + wd.non_joined_block.getByPosition(i).column->assumeMutable()->insertSelectiveRangeFrom( + *src_column, + wd.non_joined_offsets, + offset_start, + length); + } + }; + size_t offset_start = 0; + size_t offset_size = wd.non_joined_offsets.size(); + while (true) + { + size_t remaining_size = settings.max_block_size - wd.non_joined_block.rows(); + if (remaining_size > offset_size - offset_start) + { + fill_block(offset_start, offset_size - offset_start); + break; + } + fill_block(offset_start, remaining_size); + offset_start += remaining_size; + non_joined_blocks.insertFullBlock(std::move(wd.non_joined_block)); + wd.non_joined_block = {}; + if (offset_start >= offset_size) + break; + join->initOutputBlock(wd.non_joined_block); + } + } + } + + for (size_t i = 0; i < JOIN_BUILD_PARTITION_COUNT; ++i) { if (wd.partition_row_count[i] > 0) multi_row_containers[i]->insert(std::move(partition_row_container[i]), wd.partition_row_count[i]); } } -template -void insertBlockToRowContainersType( - ASTTableJoin::Kind kind, +void JoinBuildHelper::insertBlockToRowContainers( + HashJoin * join, Block & block, size_t rows, const ColumnRawPtrs & key_columns, ConstNullMapPtr null_map, - const HashJoinRowLayout & row_layout, - std::vector> & multi_row_containers, - JoinBuildWorkerData & 
worker_data, + JoinBuildWorkerData & wd, bool check_lm_row_size) { -#define CALL(has_null_map, need_record_null_rows) \ - insertBlockToRowContainersTypeImpl( \ - kind, \ - block, \ - rows, \ - key_columns, \ - null_map, \ - row_layout, \ - multi_row_containers, \ - worker_data, \ - check_lm_row_size); - - bool need_record_null_rows = needRecordNotInsertRows(kind); - if (null_map) - { - if (need_record_null_rows) - { - CALL(true, true); - } - else - { - CALL(true, false); - } +#define CALL2(KeyGetter, has_null_map, need_record_null_rows) \ + { \ + insertBlockToRowContainersImpl( \ + join, \ + block, \ + rows, \ + key_columns, \ + null_map, \ + wd, \ + check_lm_row_size); \ } - else - { - CALL(false, false); + +#define CALL1(KeyGetter) \ + { \ + bool need_record_null_rows = needRecordNotInsertRows(join->kind); \ + if (null_map) \ + { \ + if (need_record_null_rows) \ + CALL2(KeyGetter, true, true) \ + else \ + CALL2(KeyGetter, true, false) \ + } \ + else \ + CALL2(KeyGetter, false, false) \ } -#undef CALL -} -void insertBlockToRowContainers( - HashJoinKeyMethod method, - ASTTableJoin::Kind kind, - Block & block, - size_t rows, - const ColumnRawPtrs & key_columns, - ConstNullMapPtr null_map, - const HashJoinRowLayout & row_layout, - std::vector> & multi_row_containers, - JoinBuildWorkerData & worker_data, - bool check_lm_row_size) -{ - switch (method) + switch (join->method) { #define M(METHOD) \ case HashJoinKeyMethod::METHOD: \ using KeyGetterType##METHOD = HashJoinKeyGetterForType; \ - insertBlockToRowContainersType( \ - kind, \ - block, \ - rows, \ - key_columns, \ - null_map, \ - row_layout, \ - multi_row_containers, \ - worker_data, \ - check_lm_row_size); \ + CALL1(KeyGetterType##METHOD) \ break; APPLY_FOR_HASH_JOIN_VARIANTS(M) #undef M default: throw Exception( - fmt::format("Unknown JOIN keys variant {}.", magic_enum::enum_name(method)), + fmt::format("Unknown JOIN keys variant {}.", magic_enum::enum_name(join->method)), ErrorCodes::UNKNOWN_SET_DATA_VARIANT); 
} + +#undef CALL1 +#undef CALL2 } } // namespace DB diff --git a/dbms/src/Interpreters/JoinV2/HashJoinBuild.h b/dbms/src/Interpreters/JoinV2/HashJoinBuild.h index 294ef6522f2..0b81dc50820 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinBuild.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinBuild.h @@ -19,8 +19,6 @@ #include #include -#include "Parsers/ASTTablesInSelectQuery.h" - namespace DB { @@ -45,15 +43,17 @@ inline size_t getJoinBuildPartitionNum(HashValueType hash) struct alignas(CPU_CACHE_LINE_SIZE) JoinBuildWorkerData { std::unique_ptr> key_getter; - /// Count of non-null-key rows size_t row_count = 0; - - RowPtr null_rows_list_head = nullptr; + size_t non_joined_row_count = 0; PaddedPODArray row_sizes; PaddedPODArray hashes; RowPtrs row_ptrs; + IColumn::Selector right_semi_selector; + BlockSelective right_semi_offsets; + Block non_joined_block; + IColumn::Offsets non_joined_offsets; PaddedPODArray partition_row_sizes; PaddedPODArray partition_row_count; @@ -76,17 +76,30 @@ struct alignas(CPU_CACHE_LINE_SIZE) JoinBuildWorkerData size_t lm_row_count = 0; }; -void insertBlockToRowContainers( - HashJoinKeyMethod method, - ASTTableJoin::Kind kind, - Block & block, - size_t rows, - const ColumnRawPtrs & key_columns, - ConstNullMapPtr null_map, - const HashJoinRowLayout & row_layout, - std::vector> & multi_row_containers, - JoinBuildWorkerData & worker_data, - bool check_lm_row_size); +class HashJoin; +class JoinBuildHelper +{ +public: + static void insertBlockToRowContainers( + HashJoin * join, + Block & block, + size_t rows, + const ColumnRawPtrs & key_columns, + ConstNullMapPtr null_map, + JoinBuildWorkerData & wd, + bool check_lm_row_size); + +private: + template + static void NO_INLINE insertBlockToRowContainersImpl( + HashJoin * join, + Block & block, + size_t rows, + const ColumnRawPtrs & key_columns, + ConstNullMapPtr null_map, + JoinBuildWorkerData & wd, + bool check_lm_row_size); +}; } // namespace DB diff --git 
a/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp index d167522b7b0..5ce1a6eab17 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp @@ -22,7 +22,7 @@ namespace DB using enum ASTTableJoin::Kind; -JoinProbeBuildScanner::JoinProbeBuildScanner(const HashJoin * join) +JoinProbeBuildScanner::JoinProbeBuildScanner(HashJoin * join) : join(join) { join_key_getter = createHashJoinKeyGetter(join->method, join->collators); @@ -136,11 +136,16 @@ Block JoinProbeBuildScanner::scanImpl(JoinProbeWorkerData & wd) const auto & multi_row_containers = join->multi_row_containers; const size_t max_block_size = join->settings.max_block_size; const size_t left_columns = join->left_sample_block_pruned.columns(); + auto & non_joined_blocks = join->non_joined_blocks; auto & key_getter = *static_cast(join_key_getter.get()); constexpr size_t key_offset = sizeof(RowPtr) + (KeyGetterType::joinKeyCompareHashFirst() ? 
sizeof(HashValueType) : 0); + Block * non_full_block = non_joined_blocks.getNextFullBlock(); + if (non_full_block != nullptr) + return *non_full_block; + size_t scan_size = 0; RowContainer * container = wd.current_container; size_t index = wd.current_container_index; diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.h b/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.h index 0a202d44433..dfc0947c16d 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.h @@ -23,7 +23,7 @@ class HashJoin; class JoinProbeBuildScanner { public: - explicit JoinProbeBuildScanner(const HashJoin * join); + explicit JoinProbeBuildScanner(HashJoin * join); Block scan(JoinProbeWorkerData & wd); @@ -56,7 +56,7 @@ class JoinProbeBuildScanner using FuncType = Block (JoinProbeBuildScanner::*)(JoinProbeWorkerData &); FuncType scan_func_ptr = nullptr; - const HashJoin * join; + HashJoin * join; std::mutex scan_build_lock; size_t scan_build_index = 0; /// Used for deserializing join key and getting required key offset diff --git a/dbms/src/Interpreters/JoinV2/HashJoinRowLayout.h b/dbms/src/Interpreters/JoinV2/HashJoinRowLayout.h index b529ecb9e78..1407bc6d2bc 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinRowLayout.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinRowLayout.h @@ -133,7 +133,11 @@ struct RowContainer PaddedPODArray data; PaddedPODArray offsets; PaddedPODArray hashes; - /// Used for right semi/anti join + /// Only used for right semi/anti join + /// Stores the other columns that are not used for other conditions. + /// The schema corresponds to the entries in `HashJoinRowLayout::other_column_indexes` + /// after the first `other_column_count_for_other_condition` elements. + /// These indexes refer to columns in `HashJoin::right_sample_block_pruned`. 
Block other_column_block; size_t size() const { return offsets.size(); } @@ -142,15 +146,9 @@ struct RowContainer UInt64 getHash(ssize_t row) { return hashes[row]; } }; -struct alignas(CPU_CACHE_LINE_SIZE) MultipleRowContainer +class alignas(CPU_CACHE_LINE_SIZE) MultipleRowContainer { - std::mutex mu; - std::vector column_rows; - size_t all_row_count = 0; - - size_t build_table_index = 0; - size_t scan_table_index = 0; - +public: void insert(RowContainer && row_container, size_t count) { std::unique_lock lock(mu); @@ -160,19 +158,96 @@ struct alignas(CPU_CACHE_LINE_SIZE) MultipleRowContainer RowContainer * getNext() { + if (build_table_done.load(std::memory_order_relaxed)) + return nullptr; std::unique_lock lock(mu); if (build_table_index >= column_rows.size()) + { + build_table_done.store(true, std::memory_order_relaxed); return nullptr; + } return &column_rows[build_table_index++]; } RowContainer * getScanNext() { + if (scan_table_done.load(std::memory_order_relaxed)) + return nullptr; std::unique_lock lock(mu); if (scan_table_index >= column_rows.size()) + { + scan_table_done.store(true, std::memory_order_relaxed); return nullptr; + } return &column_rows[scan_table_index++]; } + +private: + std::mutex mu; + std::vector column_rows; + size_t all_row_count = 0; + + size_t build_table_index = 0; + size_t scan_table_index = 0; + + std::atomic_bool build_table_done = false; + std::atomic_bool scan_table_done = false; +}; + +class NonJoinedBlocks +{ +public: + void insertFullBlock(Block && block) + { + std::unique_lock lock(mu); + full_blocks.push_back(block); + } + + void insertNonFullBlock(Block && block) + { + std::unique_lock lock(mu); + non_full_blocks.push_back(block); + } + + Block * getNextFullBlock() + { + if (scan_full_blocks_done.load(std::memory_order_relaxed)) + return nullptr; + std::unique_lock lock(mu); + if (scan_full_blocks_index >= full_blocks.size()) + { + scan_full_blocks_done.store(true, std::memory_order_relaxed); + return nullptr; + } + return 
&full_blocks[scan_full_blocks_index++]; + } + + Block * getNextNonFullBlock() + { + if (scan_non_full_blocks_done.load(std::memory_order_relaxed)) + return nullptr; + std::unique_lock lock(mu); + if (scan_non_full_blocks_index >= non_full_blocks.size()) + { + scan_non_full_blocks_done.store(true, std::memory_order_relaxed); + return nullptr; + } + return &non_full_blocks[scan_non_full_blocks_index++]; + } + +private: + std::mutex mu; + /// Schema: HashJoin::output_block_after_finalize + /// Each block's size is equal to max_block_size + std::vector full_blocks; + /// Each block's size is less than max_block_size + std::vector non_full_blocks; + + size_t scan_full_blocks_index = 0; + size_t scan_non_full_blocks_index = 0; + + std::atomic_bool scan_full_blocks_done = false; + std::atomic_bool scan_non_full_blocks_done = false; }; } // namespace DB From 2d41fc020181c890c67cbdeebb939e5d8be0ee44 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Mon, 23 Jun 2025 23:05:55 +0800 Subject: [PATCH 76/84] fix bugs Signed-off-by: gengliqi --- dbms/src/Flash/tests/gtest_join_executor.cpp | 4 + .../src/Interpreters/JoinV2/HashJoinBuild.cpp | 20 ++-- .../JoinV2/HashJoinPointerTable.cpp | 2 +- .../src/Interpreters/JoinV2/HashJoinProbe.cpp | 14 ++- dbms/src/Interpreters/JoinV2/HashJoinProbe.h | 1 - .../JoinV2/HashJoinProbeBuildScanner.cpp | 93 +++++++++++++++++-- 6 files changed, 114 insertions(+), 20 deletions(-) diff --git a/dbms/src/Flash/tests/gtest_join_executor.cpp b/dbms/src/Flash/tests/gtest_join_executor.cpp index e0169501077..116739fcd15 100644 --- a/dbms/src/Flash/tests/gtest_join_executor.cpp +++ b/dbms/src/Flash/tests/gtest_join_executor.cpp @@ -3443,6 +3443,8 @@ try Field(static_cast(threshold))); ASSERT_COLUMNS_EQ_UR(ref, executeStreams(request, original_max_streams)) << "left_table_name = " << left_table_name << ", right_table_name = " << right_table_name; + if (cfg.enable_join_v2) + break; } WRAP_FOR_JOIN_TEST_END } @@ -3483,6 +3485,8 @@ try << "left_table_name = " << 
left_table_name << ", right_exchange_receiver_concurrency = " << exchange_concurrency << ", join_probe_cache_columns_threshold = " << threshold; + if (cfg.enable_join_v2) + break; } WRAP_FOR_JOIN_TEST_END } diff --git a/dbms/src/Interpreters/JoinV2/HashJoinBuild.cpp b/dbms/src/Interpreters/JoinV2/HashJoinBuild.cpp index 0e28ca70a0b..eeee5939850 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinBuild.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinBuild.cpp @@ -235,7 +235,7 @@ void NO_INLINE JoinBuildHelper::insertBlockToRowContainersImpl( } } - if (isRightSemiFamily(kind)) + if (is_right_semi_family) { IColumn::ScatterColumns scatter_columns(JOIN_BUILD_PARTITION_COUNT); for (size_t i = row_layout.other_column_count_for_other_condition; i < row_layout.other_column_indexes.size(); @@ -274,13 +274,21 @@ void NO_INLINE JoinBuildHelper::insertBlockToRowContainersImpl( auto fill_block = [&](size_t offset_start, size_t length) { for (size_t i = 0; i < columns; ++i) { + auto des_mut_column = wd.non_joined_block.getByPosition(i).column->assumeMutable(); const auto & name = wd.non_joined_block.getByPosition(i).name; + if (!block.has(name)) + { + // If block does not have this column, this column should be nullable and from the left side + RUNTIME_CHECK_MSG( + des_mut_column->isColumnNullable(), + "Column with name {} is not nullable", + name); + auto & nullable_column = static_cast(*des_mut_column); + nullable_column.insertManyDefaults(length); + continue; + } auto & src_column = block.getByName(name).column; - wd.non_joined_block.getByPosition(i).column->assumeMutable()->insertSelectiveRangeFrom( - *src_column, - wd.non_joined_offsets, - offset_start, - length); + des_mut_column->insertSelectiveRangeFrom(*src_column, wd.non_joined_offsets, offset_start, length); } }; size_t offset_start = 0; diff --git a/dbms/src/Interpreters/JoinV2/HashJoinPointerTable.cpp b/dbms/src/Interpreters/JoinV2/HashJoinPointerTable.cpp index 985ec40791b..a58f3e69222 100644 --- 
a/dbms/src/Interpreters/JoinV2/HashJoinPointerTable.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinPointerTable.cpp @@ -149,7 +149,7 @@ bool HashJoinPointerTable::buildImpl( if (old_head != nullptr) unalignedStore(row_ptr, old_head); } - } while (build_size < max_build_size); + } while (build_size < max_build_size * 2); wd.build_pointer_table_size += build_size; wd.build_pointer_table_time += watch.elapsedMilliseconds(); return is_end; diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index 3f800db8174..5231025b73b 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -435,7 +435,15 @@ struct JoinProbeAdder RowPtr row_ptr, size_t ptr_offset) { - wd.right_join_row_ptrs.push_back(hasRowPtrMatchedFlag(row_ptr) ? nullptr : row_ptr); + if constexpr (has_other_condition) + { + wd.right_join_row_ptrs.push_back(hasRowPtrMatchedFlag(row_ptr) ? nullptr : row_ptr); + } + else + { + setRowPtrMatchedFlag(row_ptr); + } + ++current_offset; wd.selective_offsets.push_back(idx); helper.insertRowToBatch(wd, added_columns, row_ptr + ptr_offset); @@ -701,7 +709,9 @@ Block JoinProbeHelper::probeImpl(JoinProbeContext & ctx, JoinProbeWorkerData & w for (size_t i = 0; i < right_columns; ++i) wd.result_block.safeGetByPosition(left_columns + i).column = std::move(added_columns[i]); - if constexpr (kind == Inner || kind == LeftOuter || kind == Semi || kind == Anti) + if constexpr ( + kind == Inner || kind == LeftOuter || kind == RightOuter || kind == Semi || kind == Anti || kind == RightSemi + || kind == RightAnti) { if (wd.selective_offsets.empty()) return join->output_block_after_finalize; diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h index 16f881af801..1959331ec8a 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h @@ -114,7 +114,6 @@ struct 
alignas(CPU_CACHE_LINE_SIZE) JoinProbeWorkerData size_t current_container_index = 0; /// Schema: HashJoin::output_block_after_finalize Block scan_result_block; - size_t current_scan_block_rows = 0; bool is_scan_end = false; /// Metrics diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp index 5ce1a6eab17..f5a8db5c259 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp @@ -128,6 +128,8 @@ Block JoinProbeBuildScanner::scan(JoinProbeWorkerData & wd) template Block JoinProbeBuildScanner::scanImpl(JoinProbeWorkerData & wd) { + static_assert(need_row_data || need_other_block_data); + if (wd.is_scan_end) return {}; @@ -142,20 +144,61 @@ Block JoinProbeBuildScanner::scanImpl(JoinProbeWorkerData & wd) constexpr size_t key_offset = sizeof(RowPtr) + (KeyGetterType::joinKeyCompareHashFirst() ? sizeof(HashValueType) : 0); - Block * non_full_block = non_joined_blocks.getNextFullBlock(); - if (non_full_block != nullptr) - return *non_full_block; + Block * full_block = non_joined_blocks.getNextFullBlock(); + if (full_block != nullptr) + return *full_block; size_t scan_size = 0; RowContainer * container = wd.current_container; size_t index = wd.current_container_index; - size_t scan_block_rows = wd.current_scan_block_rows; wd.selective_offsets.clear(); wd.selective_offsets.reserve(max_block_size); constexpr size_t insert_batch_max_size = 256; wd.insert_batch.clear(); wd.insert_batch.reserve(insert_batch_max_size); join->initOutputBlock(wd.scan_result_block); + + Block * non_joined_non_full_block = nullptr; + size_t output_columns = wd.scan_result_block.columns(); + while (true) + { + non_joined_non_full_block = non_joined_blocks.getNextNonFullBlock(); + if (non_joined_non_full_block == nullptr) + break; + RUNTIME_CHECK(non_joined_non_full_block->columns() == output_columns); + size_t rows = 
non_joined_non_full_block->rows(); + if (rows >= max_block_size / 2) + return *non_joined_non_full_block; + + if (wd.scan_result_block.rows() + rows > max_block_size) + { + Block res_block; + res_block.swap(wd.scan_result_block); + + join->initOutputBlock(wd.scan_result_block); + for (size_t i = 0; i < output_columns; ++i) + { + auto & src_column = non_joined_non_full_block->getByPosition(i); + auto & des_column = wd.scan_result_block.getByPosition(i); + des_column.column->assumeMutable()->insertRangeFrom(*src_column.column, 0, rows); + } + return res_block; + } + + for (size_t i = 0; i < output_columns; ++i) + { + auto & src_column = non_joined_non_full_block->getByPosition(i); + auto & des_column = wd.scan_result_block.getByPosition(i); + des_column.column->assumeMutable()->insertRangeFrom(*src_column.column, 0, rows); + } + } + + size_t scan_block_rows = wd.scan_result_block.rows(); + if constexpr (need_row_data) + scan_block_rows += wd.insert_batch.size(); + else + scan_block_rows += wd.selective_offsets.size(); + do { if (container == nullptr) @@ -255,14 +298,31 @@ Block JoinProbeBuildScanner::scanImpl(JoinProbeWorkerData & wd) wd.current_container = container; wd.current_container_index = index; - wd.current_scan_block_rows = scan_block_rows; - if unlikely (wd.is_scan_end || wd.scan_result_block.rows() >= max_block_size) + if constexpr (kind == RightOuter) + { + for (size_t i = 0; i < output_columns; ++i) + { + auto des_mut_column = wd.scan_result_block.getByPosition(i).column->assumeMutable(); + size_t current_rows = des_mut_column->size(); + if (current_rows < scan_block_rows) + { + // This column should be nullable and from the left side + RUNTIME_CHECK_MSG( + des_mut_column->isColumnNullable(), + "Column with name {} is not nullable", + wd.scan_result_block.getByPosition(i).name); + auto & nullable_column = static_cast(*des_mut_column); + nullable_column.insertManyDefaults(scan_block_rows - current_rows); + } + } + } + + if unlikely (wd.is_scan_end || 
scan_block_rows >= max_block_size) { - if (wd.scan_result_block.rows() == 0) + if (scan_block_rows == 0) return {}; - wd.current_scan_block_rows = 0; Block res_block; res_block.swap(wd.scan_result_block); return res_block; @@ -297,16 +357,29 @@ void JoinProbeBuildScanner::flushInsertBatch(JoinProbeWorkerData & wd) const other_column_count = join->row_layout.other_column_indexes.size(); else other_column_count = join->row_layout.other_column_count_for_other_condition; + + const size_t invalid_start_offset = other_column_count; + size_t advance_start_offset = invalid_start_offset; for (size_t i = 0; i < other_column_count; ++i) { size_t column_index = join->row_layout.other_column_indexes[i].first; auto output_index = join->output_column_indexes.at(left_columns + column_index); if (output_index < 0) { - join->right_sample_block_pruned.safeGetByPosition(column_index) - .column->deserializeAndAdvancePos(wd.insert_batch); + advance_start_offset = std::min(advance_start_offset, i); continue; } + if (advance_start_offset != invalid_start_offset) + { + while (advance_start_offset < i) + { + size_t column_index = join->row_layout.other_column_indexes[advance_start_offset].first; + join->right_sample_block_pruned.safeGetByPosition(column_index) + .column->deserializeAndAdvancePos(wd.insert_batch); + ++advance_start_offset; + } + advance_start_offset = invalid_start_offset; + } auto & des_column = wd.scan_result_block.safeGetByPosition(output_index); des_column.column->assumeMutable()->deserializeAndInsertFromPos(wd.insert_batch, true); if constexpr (last_flush) From 2f8380efcebf5196b1a720ec62e505dac9654f61 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Mon, 23 Jun 2025 23:32:18 +0800 Subject: [PATCH 77/84] u Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/HashJoin.cpp | 23 +---------------------- 1 file changed, 1 insertion(+), 22 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoin.cpp b/dbms/src/Interpreters/JoinV2/HashJoin.cpp index 
ee4c55202a5..9231f51b78d 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoin.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoin.cpp @@ -352,28 +352,7 @@ void HashJoin::initProbe(const Block & sample_block, size_t probe_concurrency_) } left_sample_block_pruned = left_sample_block; - - const NameSet & probe_output_name_set = has_other_condition - ? output_columns_names_set_for_other_condition_after_finalize - : output_column_names_set_after_finalize; - for (size_t pos = 0; pos < left_sample_block_pruned.columns();) - { - if (!probe_output_name_set.contains(left_sample_block_pruned.getByPosition(pos).name)) - { - if (std::find( - key_names_left.begin(), - key_names_left.end(), - left_sample_block_pruned.getByPosition(pos).name) - == key_names_left.end()) - { - LOG_ERROR(log, "shit"); - } - left_sample_block_pruned.erase(pos); - } - else - ++pos; - } - //removeUselessColumn(left_sample_block_pruned); + removeUselessColumn(left_sample_block_pruned); all_sample_block_pruned = left_sample_block_pruned.cloneEmpty(); size_t right_columns = right_sample_block_pruned.columns(); From 0d1e83da80c94e56db4f96ac8126b2972b642c95 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Wed, 30 Jul 2025 18:06:19 +0800 Subject: [PATCH 78/84] fix tests Signed-off-by: gengliqi --- .../JoinV2/gtest_semi_join_probe_list.cpp | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/dbms/src/Interpreters/tests/JoinV2/gtest_semi_join_probe_list.cpp b/dbms/src/Interpreters/tests/JoinV2/gtest_semi_join_probe_list.cpp index f9f48e56be9..326056a9e8e 100644 --- a/dbms/src/Interpreters/tests/JoinV2/gtest_semi_join_probe_list.cpp +++ b/dbms/src/Interpreters/tests/JoinV2/gtest_semi_join_probe_list.cpp @@ -37,7 +37,7 @@ try std::mt19937 g(rd()); std::uniform_int_distribution dist; - size_t n = dist(g) % 1000 + 1000; + size_t n = dist(g) % 100000 + 100000; list.reset(n); EXPECT_EQ(list.slotCapacity(), n); @@ -45,10 +45,18 @@ try std::unordered_set s1, s2; for (size_t i = 0; i < n; ++i) 
s1.insert(i); - while (!s1.empty() && !s2.empty()) + size_t active_slots = 0; + while (!s1.empty() || !s2.empty()) { - EXPECT_EQ(list.activeSlots(), s1.size()); - bool is_append = !s1.empty() && dist(g) % 2 == 0; + EXPECT_EQ(list.activeSlots(), active_slots); + + bool is_append; + if (s1.empty()) + is_append = false; + else if (s2.empty()) + is_append = true; + else + is_append = dist(g) % 2 == 0; if (is_append) { size_t append_idx = *s1.begin(); @@ -56,12 +64,14 @@ try s1.erase(append_idx); s2.insert(append_idx); list.append(append_idx); + ++active_slots; continue; } size_t remove_idx = *s2.begin(); EXPECT_TRUE(list.contains(remove_idx)); s2.erase(remove_idx); list.remove(remove_idx); + --active_slots; } EXPECT_EQ(list.slotCapacity(), n); EXPECT_EQ(list.activeSlots(), 0); From 8e75569a4d7ae4a4a0275f0df49c182afd73649b Mon Sep 17 00:00:00 2001 From: gengliqi Date: Thu, 7 Aug 2025 18:08:57 +0800 Subject: [PATCH 79/84] address comments Signed-off-by: gengliqi --- .../JoinV2/HashJoinProbeBuildScanner.cpp | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp index f5a8db5c259..81a9f5ac4f9 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp @@ -207,18 +207,16 @@ Block JoinProbeBuildScanner::scanImpl(JoinProbeWorkerData & wd) container = multi_row_containers[wd.current_scan_table_index]->getScanNext(); if (container == nullptr) { + std::unique_lock lock(scan_build_lock); + for (size_t i = 0; i < JOIN_BUILD_PARTITION_COUNT; ++i) { - std::unique_lock lock(scan_build_lock); - for (size_t i = 0; i < JOIN_BUILD_PARTITION_COUNT; ++i) + scan_build_index = (scan_build_index + i) % JOIN_BUILD_PARTITION_COUNT; + container = multi_row_containers[scan_build_index]->getScanNext(); + if (container != nullptr) { - scan_build_index = (scan_build_index + 
i) % JOIN_BUILD_PARTITION_COUNT; - container = multi_row_containers[scan_build_index]->getScanNext(); - if (container != nullptr) - { - wd.current_scan_table_index = scan_build_index; - scan_build_index = (scan_build_index + 1) % JOIN_BUILD_PARTITION_COUNT; - break; - } + wd.current_scan_table_index = scan_build_index; + scan_build_index = (scan_build_index + 1) % JOIN_BUILD_PARTITION_COUNT; + break; } } } @@ -229,6 +227,7 @@ Block JoinProbeBuildScanner::scanImpl(JoinProbeWorkerData & wd) } } size_t rows = container->size(); + size_t original_index = index; while (index < rows) { RowPtr ptr = container->getRowPtr(index); @@ -281,7 +280,7 @@ Block JoinProbeBuildScanner::scanImpl(JoinProbeWorkerData & wd) wd.selective_offsets.clear(); } - scan_size += rows - index; + scan_size += index - original_index; if (index >= rows) { From c13d2d38d7584c973141ff372828f07e6bffa096 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Thu, 7 Aug 2025 18:24:11 +0800 Subject: [PATCH 80/84] u Signed-off-by: gengliqi --- dbms/CMakeLists.txt | 2 +- dbms/src/Interpreters/JoinV2/HashJoin.cpp | 6 +++--- dbms/src/Interpreters/JoinV2/HashJoin.h | 8 ++++---- ...nner.cpp => HashJoinBuildScannerAfterProbe.cpp} | 14 +++++++------- ...dScanner.h => HashJoinBuildScannerAfterProbe.h} | 6 +++--- dbms/src/Operators/HashJoinV2ProbeTransformOp.cpp | 2 +- 6 files changed, 19 insertions(+), 19 deletions(-) rename dbms/src/Interpreters/JoinV2/{HashJoinProbeBuildScanner.cpp => HashJoinBuildScannerAfterProbe.cpp} (96%) rename dbms/src/Interpreters/JoinV2/{HashJoinProbeBuildScanner.h => HashJoinBuildScannerAfterProbe.h} (91%) diff --git a/dbms/CMakeLists.txt b/dbms/CMakeLists.txt index 8b432defc38..ea7b32d65e2 100644 --- a/dbms/CMakeLists.txt +++ b/dbms/CMakeLists.txt @@ -114,7 +114,7 @@ check_then_add_sources_compile_flag ( src/Interpreters/JoinV2/HashJoinBuild.cpp src/Interpreters/JoinV2/HashJoinProbe.cpp src/Interpreters/JoinV2/SemiJoinProbe.cpp - src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp + 
src/Interpreters/JoinV2/HashJoinBuildScannerAfterProbe.cpp src/IO/Compression/EncodingUtil.cpp src/Storages/DeltaMerge/BitmapFilter/BitmapFilter.cpp src/Storages/DeltaMerge/DMVersionFilterBlockInputStream.cpp diff --git a/dbms/src/Interpreters/JoinV2/HashJoin.cpp b/dbms/src/Interpreters/JoinV2/HashJoin.cpp index 9231f51b78d..8dfbfbeb05c 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoin.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoin.cpp @@ -426,7 +426,7 @@ void HashJoin::initProbe(const Block & sample_block, size_t probe_concurrency_) probe_workers_data.resize(probe_concurrency); if (needProbeScanBuildSide()) - join_probe_build_scanner = std::make_unique(this); + join_build_scanner_after_probe = std::make_unique(this); probe_initialized = true; } @@ -697,12 +697,12 @@ bool HashJoin::needProbeScanBuildSide() const return isRightOuterJoin(kind) || isRightSemiFamily(kind); } -Block HashJoin::probeScanBuildSide(size_t stream_index) +Block HashJoin::scanBuildSideAfterProbe(size_t stream_index) { auto & wd = probe_workers_data[stream_index]; Stopwatch all_watch; SCOPE_EXIT({ probe_workers_data[stream_index].scan_build_side_time += all_watch.elapsedFromLastTime(); }); - return join_probe_build_scanner->scan(wd); + return join_build_scanner_after_probe->scan(wd); } void HashJoin::removeUselessColumn(Block & block) const diff --git a/dbms/src/Interpreters/JoinV2/HashJoin.h b/dbms/src/Interpreters/JoinV2/HashJoin.h index 4d82b55ef81..5d65aed420c 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoin.h +++ b/dbms/src/Interpreters/JoinV2/HashJoin.h @@ -23,10 +23,10 @@ #include #include #include +#include #include #include #include -#include #include #include #include @@ -67,7 +67,7 @@ class HashJoin Block probeLastResultBlock(size_t stream_index); bool needProbeScanBuildSide() const; - Block probeScanBuildSide(size_t stream_index); + Block scanBuildSideAfterProbe(size_t stream_index); void removeUselessColumn(Block & block) const; /// Block's schema must be 
all_sample_block_pruned. @@ -102,7 +102,7 @@ class HashJoin friend JoinBuildHelper; friend JoinProbeHelper; friend SemiJoinProbeHelper; - friend JoinProbeBuildScanner; + friend JoinBuildScannerAfterProbe; static const DataTypePtr match_helper_type; @@ -177,7 +177,7 @@ class HashJoin std::unique_ptr join_probe_helper; std::unique_ptr semi_join_probe_helper; /// Probe scan build side - std::unique_ptr join_probe_build_scanner; + std::unique_ptr join_build_scanner_after_probe; const JoinProfileInfoPtr profile_info = std::make_shared(); diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp b/dbms/src/Interpreters/JoinV2/HashJoinBuildScannerAfterProbe.cpp similarity index 96% rename from dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp rename to dbms/src/Interpreters/JoinV2/HashJoinBuildScannerAfterProbe.cpp index 81a9f5ac4f9..e8c013dc125 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinBuildScannerAfterProbe.cpp @@ -13,7 +13,7 @@ // limitations under the License. 
#include -#include +#include #include #include @@ -22,7 +22,7 @@ namespace DB using enum ASTTableJoin::Kind; -JoinProbeBuildScanner::JoinProbeBuildScanner(HashJoin * join) +JoinBuildScannerAfterProbe::JoinBuildScannerAfterProbe(HashJoin * join) : join(join) { join_key_getter = createHashJoinKeyGetter(join->method, join->collators); @@ -70,7 +70,7 @@ JoinProbeBuildScanner::JoinProbeBuildScanner(HashJoin * join) #define SET_FUNC_PTR(KeyGetter, JoinType, need_row_data, need_other_block_data) \ { \ - scan_func_ptr = &JoinProbeBuildScanner::scanImpl; \ + scan_func_ptr = &JoinBuildScannerAfterProbe::scanImpl; \ } #define CALL2(KeyGetter, JoinType) \ @@ -120,13 +120,13 @@ JoinProbeBuildScanner::JoinProbeBuildScanner(HashJoin * join) #undef SET_FUNC_PTR } -Block JoinProbeBuildScanner::scan(JoinProbeWorkerData & wd) +Block JoinBuildScannerAfterProbe::scan(JoinProbeWorkerData & wd) { return (this->*scan_func_ptr)(wd); } template -Block JoinProbeBuildScanner::scanImpl(JoinProbeWorkerData & wd) +Block JoinBuildScannerAfterProbe::scanImpl(JoinProbeWorkerData & wd) { static_assert(need_row_data || need_other_block_data); @@ -330,7 +330,7 @@ Block JoinProbeBuildScanner::scanImpl(JoinProbeWorkerData & wd) } template -void JoinProbeBuildScanner::flushInsertBatch(JoinProbeWorkerData & wd) const +void JoinBuildScannerAfterProbe::flushInsertBatch(JoinProbeWorkerData & wd) const { const size_t left_columns = join->left_sample_block_pruned.columns(); for (auto [column_index, is_nullable] : join->row_layout.raw_key_column_indexes) @@ -388,7 +388,7 @@ void JoinProbeBuildScanner::flushInsertBatch(JoinProbeWorkerData & wd) const wd.insert_batch.clear(); } -void JoinProbeBuildScanner::fillNullMapWithZero(JoinProbeWorkerData & wd) const +void JoinBuildScannerAfterProbe::fillNullMapWithZero(JoinProbeWorkerData & wd) const { size_t left_columns = join->left_sample_block_pruned.columns(); for (auto [column_index, is_nullable] : join->row_layout.raw_key_column_indexes) diff --git 
a/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.h b/dbms/src/Interpreters/JoinV2/HashJoinBuildScannerAfterProbe.h similarity index 91% rename from dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.h rename to dbms/src/Interpreters/JoinV2/HashJoinBuildScannerAfterProbe.h index dfc0947c16d..4c1d10e8058 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbeBuildScanner.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinBuildScannerAfterProbe.h @@ -20,10 +20,10 @@ namespace DB { class HashJoin; -class JoinProbeBuildScanner +class JoinBuildScannerAfterProbe { public: - explicit JoinProbeBuildScanner(HashJoin * join); + explicit JoinBuildScannerAfterProbe(HashJoin * join); Block scan(JoinProbeWorkerData & wd); @@ -53,7 +53,7 @@ class JoinProbeBuildScanner void fillNullMapWithZero(JoinProbeWorkerData & wd) const; private: - using FuncType = Block (JoinProbeBuildScanner::*)(JoinProbeWorkerData &); + using FuncType = Block (JoinBuildScannerAfterProbe::*)(JoinProbeWorkerData &); FuncType scan_func_ptr = nullptr; HashJoin * join; diff --git a/dbms/src/Operators/HashJoinV2ProbeTransformOp.cpp b/dbms/src/Operators/HashJoinV2ProbeTransformOp.cpp index e3a0850a309..0c5689cac4d 100644 --- a/dbms/src/Operators/HashJoinV2ProbeTransformOp.cpp +++ b/dbms/src/Operators/HashJoinV2ProbeTransformOp.cpp @@ -75,7 +75,7 @@ OperatorStatus HashJoinV2ProbeTransformOp::onOutput(Block & block) } return OperatorStatus::WAIT_FOR_NOTIFY; case ProbeStatus::SCAN_BUILD_SIDE: - block = join_ptr->probeScanBuildSide(op_index); + block = join_ptr->scanBuildSideAfterProbe(op_index); scan_hash_map_rows += block.rows(); if unlikely (!block) status = ProbeStatus::FINISHED; From 3ee90e9f7485035fc89db75705b8b0c4fad6ace2 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Thu, 7 Aug 2025 18:54:25 +0800 Subject: [PATCH 81/84] format Signed-off-by: gengliqi --- .../Interpreters/JoinV2/HashJoinBuildScannerAfterProbe.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git 
a/dbms/src/Interpreters/JoinV2/HashJoinBuildScannerAfterProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinBuildScannerAfterProbe.cpp index e8c013dc125..81501409d23 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinBuildScannerAfterProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinBuildScannerAfterProbe.cpp @@ -68,9 +68,10 @@ JoinBuildScannerAfterProbe::JoinBuildScannerAfterProbe(HashJoin * join) RUNTIME_CHECK(need_row_data || need_other_block_data); } -#define SET_FUNC_PTR(KeyGetter, JoinType, need_row_data, need_other_block_data) \ - { \ - scan_func_ptr = &JoinBuildScannerAfterProbe::scanImpl; \ +#define SET_FUNC_PTR(KeyGetter, JoinType, need_row_data, need_other_block_data) \ + { \ + scan_func_ptr \ + = &JoinBuildScannerAfterProbe::scanImpl; \ } #define CALL2(KeyGetter, JoinType) \ From 89f6b373469f1bec6d0756b43a9e10c05f323702 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Thu, 7 Aug 2025 19:10:57 +0800 Subject: [PATCH 82/84] optimize Signed-off-by: gengliqi --- .../src/Interpreters/JoinV2/HashJoinProbe.cpp | 20 +++++++++++++++++-- .../Interpreters/JoinV2/HashJoinRowLayout.h | 4 ---- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp index 5231025b73b..b9e9ae2d2d0 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.cpp @@ -839,10 +839,18 @@ void JoinProbeHelper::probeFillColumns(JoinProbeContext & ctx, JoinProbeWorkerDa const auto & key2 = key_getter.deserializeJoinKey(ptr + key_offset); bool key_is_equal = joinKeyIsEqual(key_getter, key, key2, hash, ptr); collision += !key_is_equal; + if constexpr (Adder::need_not_matched) + is_matched |= key_is_equal; if (key_is_equal) { - if constexpr (Adder::need_not_matched) - is_matched = true; + if constexpr ((kind == RightSemi || kind == RightAnti) && !has_other_condition) + { + if (hasRowPtrMatchedFlag(ptr)) + { + ptr = nullptr; + break; + } + } if constexpr 
(Adder::need_matched) { @@ -988,6 +996,14 @@ void JoinProbeHelper::probeFillColumnsPrefetch( collision += !key_is_equal; if constexpr (Adder::need_not_matched) state->is_matched |= key_is_equal; + if constexpr ((kind == RightSemi || kind == RightAnti) && !has_other_condition) + { + if (key_is_equal && hasRowPtrMatchedFlag(ptr)) + { + next_ptr = nullptr; + key_is_equal = false; + } + } if (key_is_equal) { if constexpr (Adder::need_matched) diff --git a/dbms/src/Interpreters/JoinV2/HashJoinRowLayout.h b/dbms/src/Interpreters/JoinV2/HashJoinRowLayout.h index 1407bc6d2bc..c1debe10676 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinRowLayout.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinRowLayout.h @@ -87,8 +87,6 @@ inline bool hasRowPtrMatchedFlag(RowPtr ptr) inline void setRowPtrMatchedFlag(RowPtr ptr) { - if (hasRowPtrMatchedFlag(ptr)) - return; reinterpret_cast *>(ptr)->fetch_or(0x01, std::memory_order_relaxed); } @@ -99,8 +97,6 @@ inline bool hasRowPtrNullFlag(RowPtr ptr) inline void setRowPtrNullFlag(RowPtr ptr) { - if (hasRowPtrNullFlag(ptr)) - return; reinterpret_cast *>(ptr)->fetch_or(0x10, std::memory_order_relaxed); } From 68d3e81b99046dddf111054e8025c92ff9fcee81 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Fri, 8 Aug 2025 20:50:18 +0800 Subject: [PATCH 83/84] address comments Signed-off-by: gengliqi --- dbms/src/Operators/HashJoinV2ProbeTransformOp.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbms/src/Operators/HashJoinV2ProbeTransformOp.cpp b/dbms/src/Operators/HashJoinV2ProbeTransformOp.cpp index 0c5689cac4d..18a4633b080 100644 --- a/dbms/src/Operators/HashJoinV2ProbeTransformOp.cpp +++ b/dbms/src/Operators/HashJoinV2ProbeTransformOp.cpp @@ -93,7 +93,7 @@ OperatorStatus HashJoinV2ProbeTransformOp::transformImpl(Block & block) assert(probe_context.isAllFinished()); if likely (block) { - if (block.rows() == 0) + if unlikely (block.rows() == 0) return OperatorStatus::NEED_INPUT; probe_context.resetBlock(block); } From 
7dc1e62ab5df8b0727f488d85dac19a3a7861c31 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Thu, 20 Nov 2025 23:16:09 +0800 Subject: [PATCH 84/84] address comments Signed-off-by: gengliqi --- dbms/src/Interpreters/JoinV2/HashJoin.cpp | 4 +- dbms/src/Interpreters/JoinV2/HashJoin.h | 2 +- .../src/Interpreters/JoinV2/HashJoinBuild.cpp | 8 +- dbms/src/Interpreters/JoinV2/HashJoinBuild.h | 2 +- .../JoinV2/HashJoinBuildScannerAfterProbe.cpp | 105 +++++++++--------- dbms/src/Interpreters/JoinV2/HashJoinProbe.h | 3 + .../Interpreters/JoinV2/HashJoinRowLayout.h | 8 +- .../Operators/HashJoinV2ProbeTransformOp.cpp | 3 +- 8 files changed, 68 insertions(+), 67 deletions(-) diff --git a/dbms/src/Interpreters/JoinV2/HashJoin.cpp b/dbms/src/Interpreters/JoinV2/HashJoin.cpp index 8dfbfbeb05c..c499e7a658a 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoin.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoin.cpp @@ -425,7 +425,7 @@ void HashJoin::initProbe(const Block & sample_block, size_t probe_concurrency_) active_probe_worker = probe_concurrency; probe_workers_data.resize(probe_concurrency); - if (needProbeScanBuildSide()) + if (needScanBuildSideAfterProbe()) join_build_scanner_after_probe = std::make_unique(this); probe_initialized = true; @@ -692,7 +692,7 @@ Block HashJoin::probeLastResultBlock(size_t stream_index) return {}; } -bool HashJoin::needProbeScanBuildSide() const +bool HashJoin::needScanBuildSideAfterProbe() const { return isRightOuterJoin(kind) || isRightSemiFamily(kind); } diff --git a/dbms/src/Interpreters/JoinV2/HashJoin.h b/dbms/src/Interpreters/JoinV2/HashJoin.h index 5d65aed420c..426971571ca 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoin.h +++ b/dbms/src/Interpreters/JoinV2/HashJoin.h @@ -66,7 +66,7 @@ class HashJoin Block probeBlock(JoinProbeContext & ctx, size_t stream_index); Block probeLastResultBlock(size_t stream_index); - bool needProbeScanBuildSide() const; + bool needScanBuildSideAfterProbe() const; Block scanBuildSideAfterProbe(size_t stream_index); void 
removeUselessColumn(Block & block) const; diff --git a/dbms/src/Interpreters/JoinV2/HashJoinBuild.cpp b/dbms/src/Interpreters/JoinV2/HashJoinBuild.cpp index eeee5939850..34855920b22 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinBuild.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinBuild.cpp @@ -59,8 +59,8 @@ void NO_INLINE JoinBuildHelper::insertBlockToRowContainersImpl( wd.right_semi_selector.reserve(rows); if constexpr (has_null_map) { - wd.right_semi_offsets.clear(); - wd.right_semi_offsets.reserve(rows); + wd.right_semi_selective.clear(); + wd.right_semi_selective.reserve(rows); } } else @@ -116,7 +116,7 @@ void NO_INLINE JoinBuildHelper::insertBlockToRowContainersImpl( continue; } if (is_right_semi_family) - wd.right_semi_offsets.push_back(i); + wd.right_semi_selective.push_back(i); } const auto & key = key_getter.getJoinKeyWithBuffer(i); @@ -258,7 +258,7 @@ void NO_INLINE JoinBuildHelper::insertBlockToRowContainersImpl( partition_row_container[j].other_column_block.insert(std::move(new_column_data)); } if constexpr (has_null_map) - column_data.column->scatterTo(scatter_columns, wd.right_semi_selector, wd.right_semi_offsets); + column_data.column->scatterTo(scatter_columns, wd.right_semi_selector, wd.right_semi_selective); else column_data.column->scatterTo(scatter_columns, wd.right_semi_selector); } diff --git a/dbms/src/Interpreters/JoinV2/HashJoinBuild.h b/dbms/src/Interpreters/JoinV2/HashJoinBuild.h index 0b81dc50820..5b8e18ce7b3 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinBuild.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinBuild.h @@ -51,7 +51,7 @@ struct alignas(CPU_CACHE_LINE_SIZE) JoinBuildWorkerData RowPtrs row_ptrs; IColumn::Selector right_semi_selector; - BlockSelective right_semi_offsets; + BlockSelective right_semi_selective; Block non_joined_block; IColumn::Offsets non_joined_offsets; diff --git a/dbms/src/Interpreters/JoinV2/HashJoinBuildScannerAfterProbe.cpp b/dbms/src/Interpreters/JoinV2/HashJoinBuildScannerAfterProbe.cpp index 
81501409d23..7eb19b9ef42 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinBuildScannerAfterProbe.cpp +++ b/dbms/src/Interpreters/JoinV2/HashJoinBuildScannerAfterProbe.cpp @@ -17,6 +17,8 @@ #include #include +#include "Core/Block.h" + namespace DB { @@ -53,16 +55,17 @@ JoinBuildScannerAfterProbe::JoinBuildScannerAfterProbe(HashJoin * join) if (need_row_data) break; } - for (auto [column_index, _] : join->row_layout.other_column_indexes) + for (size_t i = 0; i < join->row_layout.other_column_count_for_other_condition; ++i) { + size_t column_index = join->row_layout.other_column_indexes[i].first; auto output_index = join->output_column_indexes.at(left_columns + column_index); need_row_data |= output_index >= 0; if (need_row_data) break; } - need_other_block_data = (kind == RightSemi || kind == RightAnti) - && join->row_layout.other_column_indexes.size() > join->row_layout.other_column_count_for_other_condition; + need_other_block_data + = join->row_layout.other_column_indexes.size() > join->row_layout.other_column_count_for_other_condition; // The output data should not be empty RUNTIME_CHECK(need_row_data || need_other_block_data); @@ -136,69 +139,63 @@ Block JoinBuildScannerAfterProbe::scanImpl(JoinProbeWorkerData & wd) using KeyGetterType = typename KeyGetter::Type; using HashValueType = typename KeyGetter::HashValueType; - const auto & multi_row_containers = join->multi_row_containers; - const size_t max_block_size = join->settings.max_block_size; - const size_t left_columns = join->left_sample_block_pruned.columns(); - auto & non_joined_blocks = join->non_joined_blocks; - - auto & key_getter = *static_cast(join_key_getter.get()); - constexpr size_t key_offset - = sizeof(RowPtr) + (KeyGetterType::joinKeyCompareHashFirst() ? 
sizeof(HashValueType) : 0); + auto & non_joined_blocks = join->non_joined_blocks; Block * full_block = non_joined_blocks.getNextFullBlock(); if (full_block != nullptr) return *full_block; - size_t scan_size = 0; - RowContainer * container = wd.current_container; - size_t index = wd.current_container_index; - wd.selective_offsets.clear(); - wd.selective_offsets.reserve(max_block_size); - constexpr size_t insert_batch_max_size = 256; - wd.insert_batch.clear(); - wd.insert_batch.reserve(insert_batch_max_size); - join->initOutputBlock(wd.scan_result_block); - - Block * non_joined_non_full_block = nullptr; - size_t output_columns = wd.scan_result_block.columns(); + const size_t max_block_size = join->settings.max_block_size; while (true) { - non_joined_non_full_block = non_joined_blocks.getNextNonFullBlock(); + Block * non_joined_non_full_block = non_joined_blocks.getNextNonFullBlock(); if (non_joined_non_full_block == nullptr) break; - RUNTIME_CHECK(non_joined_non_full_block->columns() == output_columns); + RUNTIME_CHECK(non_joined_non_full_block->columns() == join->output_block_after_finalize.columns()); size_t rows = non_joined_non_full_block->rows(); if (rows >= max_block_size / 2) return *non_joined_non_full_block; - if (wd.scan_result_block.rows() + rows > max_block_size) + if (wd.non_joined_non_full_blocks_rows + rows > max_block_size) { - Block res_block; - res_block.swap(wd.scan_result_block); - - join->initOutputBlock(wd.scan_result_block); - for (size_t i = 0; i < output_columns; ++i) - { - auto & src_column = non_joined_non_full_block->getByPosition(i); - auto & des_column = wd.scan_result_block.getByPosition(i); - des_column.column->assumeMutable()->insertRangeFrom(*src_column.column, 0, rows); - } + Block res_block = vstackBlocks(std::move(wd.non_joined_non_full_blocks)); + wd.non_joined_non_full_blocks.clear(); + wd.non_joined_non_full_blocks.emplace_back(); + wd.non_joined_non_full_blocks.back().swap(*non_joined_non_full_block); + 
wd.non_joined_non_full_blocks_rows = rows; return res_block; } - for (size_t i = 0; i < output_columns; ++i) - { - auto & src_column = non_joined_non_full_block->getByPosition(i); - auto & des_column = wd.scan_result_block.getByPosition(i); - des_column.column->assumeMutable()->insertRangeFrom(*src_column.column, 0, rows); - } + wd.non_joined_non_full_blocks.emplace_back(); + wd.non_joined_non_full_blocks.back().swap(*non_joined_non_full_block); + wd.non_joined_non_full_blocks_rows += rows; + } + if (wd.non_joined_non_full_blocks_rows > 0) + { + Block res_block = vstackBlocks(std::move(wd.non_joined_non_full_blocks)); + wd.non_joined_non_full_blocks.clear(); + wd.non_joined_non_full_blocks_rows = 0; + return res_block; } - size_t scan_block_rows = wd.scan_result_block.rows(); - if constexpr (need_row_data) - scan_block_rows += wd.insert_batch.size(); - else - scan_block_rows += wd.selective_offsets.size(); + const auto & multi_row_containers = join->multi_row_containers; + const size_t left_columns = join->left_sample_block_pruned.columns(); + auto & key_getter = *static_cast(join_key_getter.get()); + constexpr size_t key_offset + = sizeof(RowPtr) + (KeyGetterType::joinKeyCompareHashFirst() ? 
sizeof(HashValueType) : 0); + + size_t scan_size = 0; + RowContainer * container = wd.current_container; + size_t index = wd.current_container_index; + wd.selective_offsets.clear(); + wd.selective_offsets.reserve(max_block_size); + constexpr size_t insert_batch_max_size = 256; + wd.insert_batch.clear(); + wd.insert_batch.reserve(insert_batch_max_size); + + join->initOutputBlock(wd.scan_result_block); + size_t output_columns = wd.scan_result_block.columns(); + size_t result_block_rows = wd.scan_result_block.rows(); do { @@ -251,8 +248,8 @@ Block JoinBuildScannerAfterProbe::scanImpl(JoinProbeWorkerData & wd) { wd.selective_offsets.push_back(index); } - ++scan_block_rows; - if unlikely (scan_block_rows >= max_block_size) + ++result_block_rows; + if unlikely (result_block_rows >= max_block_size) { ++index; break; @@ -289,7 +286,7 @@ Block JoinBuildScannerAfterProbe::scanImpl(JoinProbeWorkerData & wd) index = 0; } - if unlikely (scan_block_rows >= max_block_size) + if unlikely (result_block_rows >= max_block_size) break; } while (scan_size < 2 * max_block_size); @@ -305,7 +302,7 @@ Block JoinBuildScannerAfterProbe::scanImpl(JoinProbeWorkerData & wd) { auto des_mut_column = wd.scan_result_block.getByPosition(i).column->assumeMutable(); size_t current_rows = des_mut_column->size(); - if (current_rows < scan_block_rows) + if (current_rows < result_block_rows) { // This column should be nullable and from the left side RUNTIME_CHECK_MSG( @@ -313,14 +310,14 @@ Block JoinBuildScannerAfterProbe::scanImpl(JoinProbeWorkerData & wd) "Column with name {} is not nullable", wd.scan_result_block.getByPosition(i).name); auto & nullable_column = static_cast(*des_mut_column); - nullable_column.insertManyDefaults(scan_block_rows - current_rows); + nullable_column.insertManyDefaults(result_block_rows - current_rows); } } } - if unlikely (wd.is_scan_end || scan_block_rows >= max_block_size) + if unlikely (wd.is_scan_end || result_block_rows >= max_block_size) { - if (scan_block_rows == 0) + 
if (result_block_rows == 0) return {}; Block res_block; diff --git a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h index 1959331ec8a..289c13cde88 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinProbe.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinProbe.h @@ -114,6 +114,9 @@ struct alignas(CPU_CACHE_LINE_SIZE) JoinProbeWorkerData size_t current_container_index = 0; /// Schema: HashJoin::output_block_after_finalize Block scan_result_block; + /// Accumulate non-joined non-full blocks and output them once they approach the max block size + Blocks non_joined_non_full_blocks; + size_t non_joined_non_full_blocks_rows = 0; bool is_scan_end = false; /// Metrics diff --git a/dbms/src/Interpreters/JoinV2/HashJoinRowLayout.h b/dbms/src/Interpreters/JoinV2/HashJoinRowLayout.h index c1debe10676..7263ef6add4 100644 --- a/dbms/src/Interpreters/JoinV2/HashJoinRowLayout.h +++ b/dbms/src/Interpreters/JoinV2/HashJoinRowLayout.h @@ -207,12 +207,12 @@ class NonJoinedBlocks Block * getNextFullBlock() { - if (scan_full_blocks_done.load(std::memory_order_relaxed)) + if (scan_full_blocks_done.load(std::memory_order_acquire)) return nullptr; std::unique_lock lock(mu); if (scan_full_blocks_index >= full_blocks.size()) { - scan_full_blocks_done.store(true, std::memory_order_relaxed); + scan_full_blocks_done.store(true, std::memory_order_release); return nullptr; } return &full_blocks[scan_full_blocks_index++]; @@ -220,12 +220,12 @@ class NonJoinedBlocks Block * getNextNonFullBlock() { - if (scan_non_full_blocks_done.load(std::memory_order_relaxed)) + if (scan_non_full_blocks_done.load(std::memory_order_acquire)) return nullptr; std::unique_lock lock(mu); if (scan_non_full_blocks_index >= non_full_blocks.size()) { - scan_non_full_blocks_done.store(true, std::memory_order_relaxed); + scan_non_full_blocks_done.store(true, std::memory_order_release); return nullptr; } return &non_full_blocks[scan_non_full_blocks_index++]; diff --git 
a/dbms/src/Operators/HashJoinV2ProbeTransformOp.cpp b/dbms/src/Operators/HashJoinV2ProbeTransformOp.cpp index 18a4633b080..6c279ce5788 100644 --- a/dbms/src/Operators/HashJoinV2ProbeTransformOp.cpp +++ b/dbms/src/Operators/HashJoinV2ProbeTransformOp.cpp @@ -58,7 +58,8 @@ OperatorStatus HashJoinV2ProbeTransformOp::onOutput(Block & block) if unlikely (probe_context.isAllFinished()) { join_ptr->finishOneProbe(op_index); - status = join_ptr->needProbeScanBuildSide() ? ProbeStatus::WAIT_PROBE_FINISH : ProbeStatus::FINISHED; + status + = join_ptr->needScanBuildSideAfterProbe() ? ProbeStatus::WAIT_PROBE_FINISH : ProbeStatus::FINISHED; block = join_ptr->probeLastResultBlock(op_index); if (block) return OperatorStatus::HAS_OUTPUT;